Fix loading legacy checkpoints
Signed-off-by: Jan Lasek <[email protected]>
janekl committed Oct 1, 2024
1 parent 32503fd commit 86408cc
Showing 1 changed file with 15 additions and 0 deletions.
nemo/deploy/nlp/megatronllm_deployable.py (+15 −0)
@@ -33,6 +33,18 @@
from nemo.deploy import ITritonDeployable
from nemo.deploy.utils import cast_output, str_ndarray2list

try:
    from megatron.core.dist_checkpointing.validation import StrictHandling

    HAVE_MEGATRON_CORE = True

except (ImportError, ModuleNotFoundError) as e:

    HAVE_MEGATRON_CORE = False
    IMPORT_ERROR = (
        "megatron-core was not found. Please see the NeMo README for installation instructions: https://github.com/NVIDIA/NeMo#megatron-core."
        f" Exact error: {e}"
    )

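As a side note (not part of the diff): StrictHandling is megatron-core's enum of checkpoint load-strictness modes, and the commit stores its .value — a plain string — in the model config, presumably so the setting serializes cleanly. A minimal sketch, assuming megatron-core is installed:

from megatron.core.dist_checkpointing.validation import StrictHandling

# .value yields the plain string form of the mode; for LOG_ALL this is
# expected to be "log_all". In this mode, mismatched checkpoint keys are
# logged during load instead of raising an error.
strictness = StrictHandling.LOG_ALL.value
print(strictness)
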
@wrapt.decorator
def noop_decorator(func):
@@ -99,6 +111,8 @@ def __init__(
        num_nodes: int = 1,
        existing_model: MegatronGPTModel = None,
    ):
        if not HAVE_MEGATRON_CORE:
            raise ImportError(IMPORT_ERROR)
        if nemo_checkpoint_filepath is None and existing_model is None:
            raise ValueError(
                "MegatronLLMDeployable requires either a .nemo checkpoint filepath or an existing MegatronGPTModel, but both provided were None"
@@ -142,6 +156,7 @@ def _load_from_nemo_checkpoint(self, nemo_checkpoint_filepath: str, num_devices:
        # had to override these to make Nemotron3-22B work, see sample_sequence_batch() in text_generation_utils.py
        custom_config.activations_checkpoint_granularity = None
        custom_config.activations_checkpoint_method = None
        custom_config.dist_ckpt_load_strictness = StrictHandling.LOG_ALL.value

        self.model = MegatronGPTModel.restore_from(
            nemo_checkpoint_filepath, trainer=trainer, override_config_path=custom_config
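Finally, a rough end-to-end sketch of the fix (not part of the commit), assuming the remaining constructor arguments keep their defaults; the checkpoint path is a placeholder. With dist_ckpt_load_strictness set to "log_all", key mismatches in a legacy distributed checkpoint are logged during restore_from rather than aborting the load:

from nemo.deploy.nlp.megatronllm_deployable import MegatronLLMDeployable

# "legacy_model.nemo" stands in for a checkpoint produced before the
# strictness flag existed; any missing/unexpected keys are now logged
# by megatron-core instead of failing the restore.
deployable = MegatronLLMDeployable(nemo_checkpoint_filepath="legacy_model.nemo")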
