diff --git a/nemo/deploy/nlp/megatronllm_deployable.py b/nemo/deploy/nlp/megatronllm_deployable.py
index 1fe029f9fade..64cf6114ceba 100644
--- a/nemo/deploy/nlp/megatronllm_deployable.py
+++ b/nemo/deploy/nlp/megatronllm_deployable.py
@@ -33,6 +33,16 @@
 from nemo.deploy import ITritonDeployable
 from nemo.deploy.utils import cast_output, str_ndarray2list
 
+try:
+    from megatron.core.dist_checkpointing.validation import StrictHandling
+
+    HAVE_MEGATRON_CORE = True
+
+except (ImportError, ModuleNotFoundError) as e:
+
+    HAVE_MEGATRON_CORE = False
+    IMPORT_ERROR = e
+
 
 @wrapt.decorator
 def noop_decorator(func):
@@ -99,6 +109,8 @@ def __init__(
         num_nodes: int = 1,
         existing_model: MegatronGPTModel = None,
     ):
+        if not HAVE_MEGATRON_CORE:
+            raise IMPORT_ERROR
         if nemo_checkpoint_filepath is None and existing_model is None:
             raise ValueError(
                 "MegatronLLMDeployable requires either a .nemo checkpoint filepath or an existing MegatronGPTModel, but both provided were None"
@@ -142,6 +154,14 @@ def _load_from_nemo_checkpoint(self, nemo_checkpoint_filepath: str, num_devices:
             # had to override these to make Nemotron3-22B work, see sample_sequence_batch() in text_generation_utils.py
             custom_config.activations_checkpoint_granularity = None
             custom_config.activations_checkpoint_method = None
+            # Models trained with TE < 1.10 and loaded with TE >= 1.10 require
+            # special handling on loading checkpoint due to structural updates
+            custom_config.dist_ckpt_load_strictness = StrictHandling.LOG_ALL.value
+            if custom_config.get("fp8", False):
+                # Need to disable FP8 for in-framework inference due to shape constraints imposed by TE,
+                # see https://github.com/NVIDIA/TransformerEngine/blob/v1.10/transformer_engine/pytorch/utils.py#L229
+                LOGGER.warning("Disabling FP8 inference due to shape constraints imposed by Transformer Engine.")
+                custom_config.fp8 = False
 
             self.model = MegatronGPTModel.restore_from(
                 nemo_checkpoint_filepath, trainer=trainer, override_config_path=custom_config
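
For context, the import guard at the top of the diff follows the optional-dependency pattern used elsewhere in NeMo: the exception is captured at import time and only re-raised when the deployable is actually constructed, so importing the module still works on installs without `megatron.core`. A minimal standalone sketch of that pattern (the `Deployable` class here is a hypothetical stand-in for `MegatronLLMDeployable`, not part of the patch):

```python
# Sketch of the deferred-import-error pattern added above; Deployable is a
# hypothetical stand-in for MegatronLLMDeployable.
try:
    from megatron.core.dist_checkpointing.validation import StrictHandling

    HAVE_MEGATRON_CORE = True
except (ImportError, ModuleNotFoundError) as e:
    HAVE_MEGATRON_CORE = False
    IMPORT_ERROR = e


class Deployable:
    def __init__(self):
        # Fail at construction time rather than import time, so modules that
        # merely import this file still load when megatron.core is absent.
        if not HAVE_MEGATRON_CORE:
            raise IMPORT_ERROR
```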