-
Notifications
You must be signed in to change notification settings - Fork 2.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
In-framework inference fixes #10698
Merged
Merged
In-framework inference fixes #10698
Changes from all commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
5d6ccf7
Fix loading legacy checkpoints
janekl 01e680d
Fix inference issues FP8-trained models
janekl 1938b38
Apply isort and black reformatting
janekl 87bcae6
Comment on TE shape contraints during inference
janekl 8e1e1c8
Simplify import error handling
janekl b29975f
Comment on issues
janekl File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -33,6 +33,16 @@ | |
from nemo.deploy import ITritonDeployable | ||
from nemo.deploy.utils import cast_output, str_ndarray2list | ||
|
||
try: | ||
from megatron.core.dist_checkpointing.validation import StrictHandling | ||
|
||
HAVE_MEGATRON_CORE = True | ||
|
||
except (ImportError, ModuleNotFoundError) as e: | ||
|
||
HAVE_MEGATRON_CORE = False | ||
IMPORT_ERROR = e | ||
|
||
|
||
@wrapt.decorator | ||
def noop_decorator(func): | ||
|
@@ -99,6 +109,8 @@ def __init__( | |
num_nodes: int = 1, | ||
existing_model: MegatronGPTModel = None, | ||
): | ||
if not HAVE_MEGATRON_CORE: | ||
raise IMPORT_ERROR | ||
if nemo_checkpoint_filepath is None and existing_model is None: | ||
raise ValueError( | ||
"MegatronLLMDeployable requires either a .nemo checkpoint filepath or an existing MegatronGPTModel, but both provided were None" | ||
|
@@ -142,6 +154,14 @@ def _load_from_nemo_checkpoint(self, nemo_checkpoint_filepath: str, num_devices: | |
# had to override these to make Nemotron3-22B work, see sample_sequence_batch() in text_generation_utils.py | ||
custom_config.activations_checkpoint_granularity = None | ||
custom_config.activations_checkpoint_method = None | ||
# Models trained with TE < 1.10 and loaded with TE >= 1.10 require | ||
# special handling on loading checkpoint due to structural updates | ||
custom_config.dist_ckpt_load_strictness = StrictHandling.LOG_ALL.value | ||
if custom_config.get("fp8", False): | ||
# Need to disable FP8 for in-framework inference due to shape constraints imposed by TE, | ||
# see https://github.com/NVIDIA/TransformerEngine/blob/v1.10/transformer_engine/pytorch/utils.py#L229 | ||
LOGGER.warning("Disabling FP8 inference due to shape constraints imposed by Transformer Engine.") | ||
custom_config.fp8 = False | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think in-framework FP8 inference is not supported |
||
|
||
self.model = MegatronGPTModel.restore_from( | ||
nemo_checkpoint_filepath, trainer=trainer, override_config_path=custom_config | ||
|
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@thomasdhc @mikolajblaz @dimapihtar do you have any idea why the 1st error mentioned in MR description is visible in r2.0.0 branch but on the other hand main looks good with the same repro?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Discussed with @mikolajblaz offline, this is likely due to different TE versions used in two different containers tested
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
correct. If you don't want to import MCore you can set a string
'log_all'
hereThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it's fine to import it as it's required here anyway. This is also more transparent to me besides on what's going on.