only enable query key scaling during fp16 #7946
Conversation
```diff
@@ -1544,6 +1544,11 @@ def build_transformer_config(self) -> TransformerConfig:

         attention_softmax_in_fp32 = False  # not currently used in NeMo unless apply_query_key_layer_scaling is True
         apply_query_key_layer_scaling = self.cfg.get('apply_query_key_layer_scaling', False)

         if apply_query_key_layer_scaling and not model_parallel_config.fp16:
             logging.warning("apply_query_key_layer_scaling is only enabled when using FP16, setting it to False")
```
**Reviewer:** I don't think `model_parallel_config.fp16` is the right check, though. That arg is only set for fp16 with megatron_amp_O2. Maybe we should just check `trainer.precision`?
**Author:** Can I check against `self.torch_dtype`?

```python
self.torch_dtype = utils_funcs.torch_dtype_from_precision(self.cfg.precision)  # Mixed precision datatype
```
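A minimal sketch of the check being proposed, assuming `torch_dtype_from_precision` resolves the trainer precision setting (e.g. `16`, `"16-mixed"`, `"bf16-mixed"`, `32`) to a torch dtype; the helper below is illustrative, not NeMo's actual code:

```python
import torch

def should_apply_qk_layer_scaling(requested: bool, torch_dtype: torch.dtype) -> bool:
    # Query-key layer scaling guards against fp16 overflow in the attention
    # logits; bf16 and fp32 have enough dynamic range, so it is skipped there.
    return requested and torch_dtype == torch.float16
```

With a helper like this, the guard in `build_transformer_config` reduces to `apply_query_key_layer_scaling = should_apply_qk_layer_scaling(apply_query_key_layer_scaling, self.torch_dtype)`.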
**Author:** Fixed in 25b7349.
```python
if fp16_enabled:
    os.environ["NVTE_APPLY_QK_LAYER_SCALING"] = "1"
else:
    logging.warning("apply_query_key_layer_scaling is only enabled when using FP16, setting it to False")
```
**Reviewer:** Should we set the env var to `"0"` here as well?
**Author:** Good point, otherwise it will error. Fixed in 366cf61.
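A minimal sketch of the resulting logic after that fix. The wrapper function is hypothetical (only the two env-var assignments and the warning come from the snippets above):

```python
import logging
import os

def configure_qk_layer_scaling(fp16_enabled: bool) -> None:
    if fp16_enabled:
        os.environ["NVTE_APPLY_QK_LAYER_SCALING"] = "1"
    else:
        logging.warning(
            "apply_query_key_layer_scaling is only enabled when using FP16, setting it to False"
        )
        # Explicitly disable scaling; a value inherited from the environment
        # could otherwise leave it on and trigger the error mentioned above.
        os.environ["NVTE_APPLY_QK_LAYER_SCALING"] = "0"
```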
**Reviewer:** LGTM. Thanks!
Squashed commit history:

* only enable query key scaling during fp16
* add warning
* fixup! only enable query key scaling during fp16
* remove var from Jenkins file
* fix test by setting TE var
* set to 0 if disabled
What does this PR do?
Enable query key scaling only during fp16, since that is the only case in which TransformerEngine applies it: https://github.com/NVIDIA/TransformerEngine/blob/666539f36275fa9c0fbc99f9ea50f2d6e29e336f/transformer_engine/pytorch/attention.py#L940
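For reference, a minimal sketch of the gating this PR mirrors on the NeMo side. The helper name is hypothetical and this is not TransformerEngine's actual code; it only illustrates the behavior the linked line implies (the env var is honored solely for fp16 inputs):

```python
import os
import torch

def qk_layer_scaling_active(query_dtype: torch.dtype) -> bool:
    # TransformerEngine reads NVTE_APPLY_QK_LAYER_SCALING from the environment
    # and applies query-key layer scaling only for fp16 attention inputs.
    env_flag = bool(int(os.getenv("NVTE_APPLY_QK_LAYER_SCALING", "0")))
    return env_flag and query_dtype == torch.float16
```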