Skip to content

Commit

Permalink
only enable query key scaling during fp16 (NVIDIA#7946)
Browse files Browse the repository at this point in the history
* only enable query key scaling during fp16

Signed-off-by: Gerald Shen <[email protected]>

* add warning

Signed-off-by: Gerald Shen <[email protected]>

* fixup! only enable query key scaling during fp16

Signed-off-by: Gerald Shen <[email protected]>

* remove var from Jenkinsfile

Signed-off-by: Gerald Shen <[email protected]>

* fix test by setting TE var

Signed-off-by: Gerald Shen <[email protected]>

* set to 0 if disabled

Signed-off-by: Gerald Shen <[email protected]>

---------

Signed-off-by: Gerald Shen <[email protected]>
  • Loading branch information
gshennvm committed Dec 1, 2023
1 parent 125f874 commit 6c65d9f
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 3 deletions.
3 changes: 0 additions & 3 deletions Jenkinsfile
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,6 @@ pipeline {
timeout(time: 8, unit: 'HOURS')
disableConcurrentBuilds(abortPrevious: true)
}
environment {
NVTE_APPLY_QK_LAYER_SCALING = 1
}

stages {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1544,6 +1544,19 @@ def build_transformer_config(self) -> TransformerConfig:

attention_softmax_in_fp32 = False # not currently used in NeMo unless apply_query_key_layer_scaling is True
apply_query_key_layer_scaling = self.cfg.get('apply_query_key_layer_scaling', False)

fp16_enabled = self.trainer.precision in [16, '16', '16-mixed']
if apply_query_key_layer_scaling:
if fp16_enabled:
os.environ["NVTE_APPLY_QK_LAYER_SCALING"] = "1"
else:
logging.warning(
"apply_query_key_layer_scaling is only enabled when using FP16, setting it to False "
"and setting NVTE_APPLY_QK_LAYER_SCALING=0"
)
os.environ["NVTE_APPLY_QK_LAYER_SCALING"] = "0"
apply_query_key_layer_scaling = False

if apply_query_key_layer_scaling:
attention_softmax_in_fp32 = True

Expand All @@ -1570,6 +1583,7 @@ def build_transformer_config(self) -> TransformerConfig:

# any configs that are not in the nemo model config will be added here
config_mapping = {
'apply_query_key_layer_scaling': apply_query_key_layer_scaling,
'apply_residual_connection_post_layernorm': False, # we don't use this in NeMo
'layernorm_zero_centered_gamma': layernorm_zero_centered_gamma,
'add_bias_linear': add_bias_linear,
Expand Down

0 comments on commit 6c65d9f

Please sign in to comment.