[NeMo-UX] Integrate tokenizer import into model.import_ckpt (#9485)
* Integrate tokenizer import into model.import_ckpt

* Apply isort and black reformatting

Signed-off-by: marcromeyn <[email protected]>

* Apply isort and black reformatting

Signed-off-by: marcromeyn <[email protected]>

* Fixing bug in ModelConnector.nemo_save

* Apply isort and black reformatting

Signed-off-by: marcromeyn <[email protected]>

* Default to ddp=pytorch inside ModelConnector

* Apply isort and black reformatting

Signed-off-by: marcromeyn <[email protected]>

---------

Signed-off-by: marcromeyn <[email protected]>
Co-authored-by: marcromeyn <[email protected]>
marcromeyn authored Jun 17, 2024
1 parent d13e532 commit bfd07b9
Showing 5 changed files with 31 additions and 139 deletions.
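For context, this is roughly how the new flow looks from the caller's side. The hf:// source string and the Mistral7BModel class come from the diff below; the surrounding snippet is an illustrative sketch of the import path, not an excerpt from the repository.

# Sketch: importing a Hugging Face checkpoint into the NeMo format.
# After this commit, the registered "hf" importer also attaches its tokenizer to
# the model via the new on_import_ckpt hook, so Mistral7BModel.__init__ no longer
# resolves a tokenizer eagerly.
from nemo.collections.llm.gpt.model.mistral_7b import Mistral7BModel

model = Mistral7BModel()  # tokenizer may stay unset until a checkpoint is imported
ckpt_path = model.import_ckpt("hf://mistralai/Mistral-7B-v0.1")

# The connector converts and saves the weights under the local cache path and,
# new in this commit, sets model.tokenizer from the importer before returning.
print(ckpt_path, model.tokenizer)
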
8 changes: 5 additions & 3 deletions nemo/collections/llm/gpt/model/mistral_7b.py
@@ -2,6 +2,7 @@
 from pathlib import Path
 from typing import TYPE_CHECKING, Callable, List, Optional

+import pytorch_lightning as pl
 import torch
 import torch.nn.functional as F
 from typing_extensions import Annotated
@@ -46,9 +47,7 @@ def __init__(
         optim: Optional[OptimizerModule] = None,
         tokenizer: Optional["TokenizerSpec"] = None,
     ):
-        _tokenizer = tokenizer or HFMistral7BImporter("mistralai/Mistral-7B-v0.1").tokenizer
-
-        super().__init__(config or Mistral7BConfig(), optim=optim, tokenizer=_tokenizer)
+        super().__init__(config or Mistral7BConfig(), optim=optim, tokenizer=tokenizer)


 @io.model_importer(Mistral7BModel, "hf")
@@ -72,6 +71,9 @@ def apply(self, output_path: Path) -> Path:

         return output_path

+    def on_import_ckpt(self, model: pl.LightningModule):
+        model.tokenizer = self.tokenizer
+
     def convert_state(self, source, target):
         mapping = {
             "model.embed_tokens.weight": "embedding.word_embeddings.weight",
122 changes: 0 additions & 122 deletions nemo/lightning/experiment.py

This file was deleted.

16 changes: 13 additions & 3 deletions nemo/lightning/io/connector.py
@@ -1,3 +1,4 @@
+import inspect
 import logging
 import os
 import shutil
@@ -138,7 +139,7 @@ def nemo_setup(self, model: pl.LightningModule, trainer: Optional[pl.Trainer] =
         from nemo.lightning import MegatronStrategy, Trainer

         _trainer = trainer or Trainer(
-            devices=1, accelerator="cpu", strategy=MegatronStrategy(store_optimizer_states=False)
+            devices=1, accelerator="cpu", strategy=MegatronStrategy(store_optimizer_states=False, ddp="pytorch")
         )

         _trainer.strategy.connect(model)
@@ -159,7 +160,12 @@ def nemo_save(self, output_path: Path, trainer: pl.Trainer) -> None:
             output_path (Path): The path where the model checkpoint will be saved.
             trainer (pl.Trainer): The trainer with the strategy to save the model.
         """
-        trainer.strategy.setup(trainer)
+        _setup_kwargs = {}
+        setup_signature = inspect.signature(trainer.strategy.setup)
+        if 'setup_optimizers' in setup_signature.parameters:
+            _setup_kwargs["setup_optimizers"] = False
+
+        trainer.strategy.setup(trainer, **_setup_kwargs)
         trainer.save_checkpoint(output_path)

     def nemo_load(
@@ -181,7 +187,9 @@ def nemo_load(
         from nemo.lightning.io.api import load_ckpt

         model = load_ckpt(path).model
-        _trainer = trainer or Trainer(devices=1, accelerator="cpu" if cpu else "gpu", strategy=MegatronStrategy())
+        _trainer = trainer or Trainer(
+            devices=1, accelerator="cpu" if cpu else "gpu", strategy=MegatronStrategy(ddp="pytorch")
+        )

         _trainer.strategy.connect(model)
         _trainer.strategy.setup_environment()
@@ -208,3 +216,5 @@ def local_path(self, base_path: Optional[Path] = None) -> Path:
             _base = Path(NEMO_MODELS_CACHE)

         return _base / str(self).replace("://", "/")
+
+    def on_import_ckpt(self, model: pl.LightningModule): ...
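The nemo_save change above probes the strategy's setup() signature before passing setup_optimizers=False, so strategies that predate the new keyword keep working. A self-contained sketch of that pattern follows; OldStrategy, NewStrategy, and call_setup are hypothetical names used only to demonstrate the signature check.

import inspect


class OldStrategy:
    # Stand-in for a strategy whose setup() does not know about the new keyword.
    def setup(self, trainer):
        return "optimizers set up"


class NewStrategy:
    # Stand-in for a strategy that accepts the new defaulted keyword.
    def setup(self, trainer, setup_optimizers: bool = True):
        return "optimizers set up" if setup_optimizers else "optimizers skipped"


def call_setup(strategy, trainer=None):
    kwargs = {}
    # Pass the keyword only if the target setup() actually accepts it.
    if "setup_optimizers" in inspect.signature(strategy.setup).parameters:
        kwargs["setup_optimizers"] = False
    return strategy.setup(trainer, **kwargs)


print(call_setup(OldStrategy()))  # -> optimizers set up
print(call_setup(NewStrategy()))  # -> optimizers skipped
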
2 changes: 2 additions & 0 deletions nemo/lightning/io/mixin.py
@@ -280,6 +280,8 @@ def import_ckpt(self, path: str, overwrite: bool = False, base_path: Optional[Pa
         ckpt_path: Path = connector.local_path(base_path=base_path)
         ckpt_path = connector(ckpt_path, overwrite=overwrite)

+        connector.on_import_ckpt(self)
+
         return ckpt_path

     @classmethod
22 changes: 11 additions & 11 deletions nemo/lightning/pytorch/strategies.py
@@ -126,7 +126,7 @@ def connect(self, model: pl.LightningModule) -> None:
         self._mcore_config = config

     @override
-    def setup(self, trainer: pl.Trainer) -> None:
+    def setup(self, trainer: pl.Trainer, setup_optimizers: bool = True) -> None:
         assert self.accelerator is not None
         self.accelerator.setup(trainer)
         self.trainer = trainer
@@ -150,7 +150,7 @@ def setup(self, trainer: pl.Trainer) -> None:
             self.data_sampler.connect(trainer)

         self._fix_progress_bar(trainer)
-        self.setup_megatron_parallel(trainer)
+        self.setup_megatron_parallel(trainer, setup_optimizers=setup_optimizers)
         self.setup_precision_plugin()

         if trainer.num_sanity_val_steps > 1 and self.pipeline_model_parallel_size > 1:
@@ -205,7 +205,7 @@ def process_dataloader(self, dataloader: DataLoader) -> DataLoader:

         return dataloader

-    def setup_megatron_parallel(self, trainer: pl.Trainer) -> None:
+    def setup_megatron_parallel(self, trainer: pl.Trainer, setup_optimizers: bool = True) -> None:
         assert self.model is not None, "Model is not set"

         self.megatron_parallel = MegatronParallel(
@@ -224,16 +224,16 @@ def setup_megatron_parallel(self, trainer: pl.Trainer) -> None:
             self.model.configure_optimizers, megatron_parallel=self.megatron_parallel
         )

-        self.setup_optimizers(trainer)
+        if setup_optimizers:
+            self.setup_optimizers(trainer)

-        # TODO: Throw an execption if we have a mcore optimizer and no ddp_config
-        if hasattr(self.precision_plugin, "convert_optimizer"):
-            _optimizers = [*self.optimizers]
-            _optimizers[0] = self.precision_plugin.convert_optimizer(self.optimizers[0])
-            self.optimizers = _optimizers
+            # TODO: Throw an execption if we have a mcore optimizer and no ddp_config
+            if hasattr(self.precision_plugin, "convert_optimizer"):
+                _optimizers = [*self.optimizers]
+                _optimizers[0] = self.precision_plugin.convert_optimizer(self.optimizers[0])
+                self.optimizers = _optimizers

-        _optimizers_to_device(self.optimizers, self.root_device)
+            _optimizers_to_device(self.optimizers, self.root_device)

         self.model = self.megatron_parallel
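Because setup_optimizers defaults to True, existing callers of MegatronStrategy.setup(trainer) behave exactly as before; only the checkpoint-conversion path (ModelConnector.nemo_save) opts out, consistent with the store_optimizer_states=False trainer it builds in nemo_setup. A minimal sketch of this defaulted-keyword pattern, using a hypothetical Strategy class rather than the real MegatronStrategy:

class Strategy:
    # Hypothetical stand-in that mirrors the shape of the change above.
    def __init__(self):
        self.optimizers = []

    def setup(self, trainer, setup_optimizers: bool = True):
        # ... model and parallelism setup would happen here ...
        if setup_optimizers:
            self.setup_optimizers(trainer)

    def setup_optimizers(self, trainer):
        self.optimizers = ["adam"]


training = Strategy()
training.setup(trainer=None)  # existing call sites: optimizers are created
export = Strategy()
export.setup(trainer=None, setup_optimizers=False)  # checkpoint export: skipped

print(training.optimizers, export.optimizers)  # ['adam'] []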
