
only enable query key scaling during fp16 #7946

Merged: 7 commits merged into main from geshen/fix_query_key_layer_scaling on Dec 1, 2023

Conversation

gshennvm (Collaborator)

What does this PR do?


Enable query key scaling only during fp16 training, since that is the only precision for which TransformerEngine (TE) enables query-key layer scaling: https://github.com/NVIDIA/TransformerEngine/blob/666539f36275fa9c0fbc99f9ea50f2d6e29e336f/transformer_engine/pytorch/attention.py#L940
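
For context, a minimal sketch of the gating this PR introduces on the NeMo side. The precision-to-dtype mapping is a simplified stand-in for NeMo's utils_funcs.torch_dtype_from_precision, and trainer_precision / requested_scaling are hypothetical placeholder names; treat this as an illustration, not the exact patch:

import torch

def torch_dtype_from_precision(precision):
    # Simplified stand-in for NeMo's utils_funcs.torch_dtype_from_precision:
    # map a PyTorch Lightning precision flag to a torch dtype.
    if precision in (16, '16', '16-mixed'):
        return torch.float16
    if precision in ('bf16', 'bf16-mixed'):
        return torch.bfloat16
    return torch.float32

# Hypothetical resolution step: query-key layer scaling only survives
# config resolution when training in fp16, matching TE's behavior.
trainer_precision = '16-mixed'
requested_scaling = True
fp16_enabled = torch_dtype_from_precision(trainer_precision) == torch.float16
apply_query_key_layer_scaling = requested_scaling and fp16_enabled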

@ericharper (Collaborator)

jenkins

Signed-off-by: Gerald Shen <[email protected]>
@ericharper (Collaborator)

jenkins

@@ -1544,6 +1544,11 @@ def build_transformer_config(self) -> TransformerConfig:

    attention_softmax_in_fp32 = False  # not currently used in NeMo unless apply_query_key_layer_scaling is True
    apply_query_key_layer_scaling = self.cfg.get('apply_query_key_layer_scaling', False)

    if apply_query_key_layer_scaling and not model_parallel_config.fp16:
        logging.warning("apply_query_key_layer_scaling is only enabled when using FP16, setting it to False")
Collaborator:

I don't think model_parallel_config.fp16 is the right check though. That arg is for fp16 + megatron_amp_O2.

Maybe we should just check trainer.precision?

Collaborator Author:

Can I check against self.torch_dtype?

self.torch_dtype = utils_funcs.torch_dtype_from_precision(self.cfg.precision) # Mixed precision datatype

Collaborator Author:

Fixed in 25b7349.
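
A sketch of what the revised check might look like after that fix, comparing against the mixed-precision dtype rather than model_parallel_config.fp16 (the helper below is hypothetical; in the PR the dtype comes from self.torch_dtype, derived as shown above):

import logging
import torch

def resolve_qk_layer_scaling(requested: bool, torch_dtype: torch.dtype) -> bool:
    # Hypothetical helper mirroring the discussed check: honor
    # apply_query_key_layer_scaling only when the training dtype is fp16.
    if requested and torch_dtype != torch.float16:
        logging.warning(
            "apply_query_key_layer_scaling is only enabled when using FP16, setting it to False"
        )
        return False
    return requested

# Example: requesting scaling under bf16 gets downgraded to False.
assert resolve_qk_layer_scaling(True, torch.bfloat16) is False
assert resolve_qk_layer_scaling(True, torch.float16) is True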

@gshennvm (Collaborator, Author)

jenkins

Signed-off-by: Gerald Shen <[email protected]>
github-actions bot added the CI label on Nov 30, 2023
@gshennvm (Collaborator, Author)

jenkins

Signed-off-by: Gerald Shen <[email protected]>
@gshennvm (Collaborator, Author) commented Dec 1, 2023

jenkins

if fp16_enabled:
    os.environ["NVTE_APPLY_QK_LAYER_SCALING"] = "1"
else:
    logging.warning("apply_query_key_layer_scaling is only enabled when using FP16, setting it to False")
Collaborator:

Should we set the env var here to "0" as well?

Collaborator Author:

Good point, otherwise it will error. Fixed in 366cf61.
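
Putting the two review comments together, the resolved logic plausibly looks like the sketch below: export NVTE_APPLY_QK_LAYER_SCALING in both branches so TE never sees a stale value. This is a reconstruction from the diff and the "set to 0 if disabled" commit, not a verbatim copy of the merged code:

import logging
import os

def set_te_qk_layer_scaling_env(fp16_enabled: bool) -> None:
    # Reconstruction of the final behavior: set the TE toggle explicitly in
    # both branches so a previously exported "1" cannot leak into a
    # non-fp16 run and trigger an error inside TE.
    if fp16_enabled:
        os.environ["NVTE_APPLY_QK_LAYER_SCALING"] = "1"
    else:
        os.environ["NVTE_APPLY_QK_LAYER_SCALING"] = "0"
        logging.warning(
            "apply_query_key_layer_scaling is only enabled when using FP16, setting it to False"
        )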

Signed-off-by: Gerald Shen <[email protected]>
@gshennvm (Collaborator, Author) commented Dec 1, 2023

jenkins

@ericharper left a review comment:

LGTM. Thanks!

@gshennvm merged commit a7f0bc1 into main on Dec 1, 2023; all 15 checks passed.
@gshennvm deleted the geshen/fix_query_key_layer_scaling branch on Dec 1, 2023 at 20:30.
pzelasko pushed a commit to pzelasko/NeMo that referenced this pull request Jan 3, 2024
* only enable query key scaling during fp16

Signed-off-by: Gerald Shen <[email protected]>

* add warning

Signed-off-by: Gerald Shen <[email protected]>

* fixup! only enable query key scaling during fp16

Signed-off-by: Gerald Shen <[email protected]>

* remove var from jenkens file

Signed-off-by: Gerald Shen <[email protected]>

* fix test by setting TE var

Signed-off-by: Gerald Shen <[email protected]>

* set to 0 if disabled

Signed-off-by: Gerald Shen <[email protected]>

---------

Signed-off-by: Gerald Shen <[email protected]>
Signed-off-by: Piotr Żelasko <[email protected]>
ssh-meister pushed a commit to ssh-meister/NeMo that referenced this pull request Feb 15, 2024, with the same commit list as above (additionally Signed-off-by: Sasha Meister <[email protected]>).
rohitrango pushed a commit to rohitrango/NeMo that referenced this pull request Jun 25, 2024, with the same commit list as above.