Fix flash-attention (#6901)
* Set default apply_query_key_layer_scaling to false

Signed-off-by: hsiehjackson <[email protected]>

* Add cross attention test

Signed-off-by: hsiehjackson <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: hsiehjackson <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
hsiehjackson and pre-commit-ci[bot] committed Jun 26, 2023
1 parent e736c86 commit 8204483
Showing 9 changed files with 148 additions and 54 deletions.
@@ -63,7 +63,7 @@ model:
attention_dropout: 0.1 # Dropout probability for attention
ffn_dropout: 0.0 # Dropout probability in the feed-forward layer.
kv_channels: null # Projection weights dimension in multi-head attention. Set to hidden_size // num_attention_heads if null
-apply_query_key_layer_scaling: True # scale Q * K^T by 1 / layer-number.
+apply_query_key_layer_scaling: False # scale Q * K^T by 1 / layer-number.
normalization: 'layernorm' # Normalization layer to use. Options are 'layernorm', 'rmsnorm'
layernorm_epsilon: 1e-5
do_layer_norm_weight_decay: False # True means weight decay on all params
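For context on the flag being flipped throughout this commit: when enabled, `apply_query_key_layer_scaling` divides the raw Q * K^T scores by an extra factor of the layer number (compensated for inside the fused softmax), whereas flash-attention kernels only apply the usual 1/sqrt(head_dim) scaling, which is presumably why the default moves to False. A minimal sketch of the assumed semantics, not the NeMo/Megatron implementation:

```python
# Illustrative sketch only: extra 1/layer_number scaling of attention scores
# when apply_query_key_layer_scaling is enabled (assumed semantics).
import math
import torch

def scaled_attention_scores(q, k, layer_number, apply_query_key_layer_scaling=False):
    # q, k: [batch, heads, seq, head_dim]
    norm_factor = math.sqrt(q.size(-1))
    if apply_query_key_layer_scaling:
        # extra per-layer scaling; the fused softmax would rescale by
        # layer_number again (in fp32) to keep the math equivalent
        norm_factor *= layer_number
    return torch.matmul(q, k.transpose(-2, -1)) / norm_factor
```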
@@ -114,7 +114,7 @@ def __init__(
num_layers,
num_attention_heads,
ffn_hidden_size,
-apply_query_key_layer_scaling=True,
+apply_query_key_layer_scaling=False,
kv_channels=None,
num_tokentypes=0,
parallel_output=True,
34 changes: 11 additions & 23 deletions nemo/collections/nlp/modules/common/megatron/attention.py
@@ -109,7 +109,7 @@ def __init__(
attention_type=AttnType.self_attn,
attn_mask_type=AttnMaskType.padding,
precision=16,
-apply_query_key_layer_scaling=True,
+apply_query_key_layer_scaling=False,
kv_channels=None,
use_cpu_initialization=False,
megatron_amp_O2=False,
@@ -564,7 +564,7 @@ def __init__(
num_attention_heads,
hidden_size,
precision=16,
-apply_query_key_layer_scaling=True,
+apply_query_key_layer_scaling=False,
kv_channels=None,
use_cpu_initialization=False,
megatron_amp_O2=False,
@@ -728,7 +728,7 @@ def __init__(
attention_type=AttnType.self_attn,
attn_mask_type=AttnMaskType.padding,
precision=16,
-apply_query_key_layer_scaling=True,
+apply_query_key_layer_scaling=False,
kv_channels=None,
masked_softmax_fusion=True,
attention_dropout=0.1,
@@ -928,7 +928,6 @@ def torch_attention(self, query_layer, key_layer, value_layer, attention_mask, a
attention_scores += attention_bias

attention_probs = self.scale_mask_softmax(attention_scores, attention_mask)

# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.

@@ -966,15 +965,6 @@ def flash_attention(self, query_layer, key_layer, value_layer, attention_mask, a
else:
return self.flash_attention_cuda(query_layer, key_layer, value_layer, attention_mask,)

-def reset_is_causal(self, query_length, key_length, causal):
-if query_length != key_length:
-if query_length == 1:
-return False
-raise NotImplementedError(
-"Flash attention does not support query and key with different number of tokens, unless number of query tokens is 1."
-)
-return causal

def flash_attention_cuda(self, query_layer, key_layer, value_layer, attention_mask):
batch_size, seqlen, nheads, _ = query_layer.shape

@@ -994,9 +984,7 @@ def flash_attention_cuda(self, query_layer, key_layer, value_layer, attention_ma
q, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(query_layer, attention_mask_q)
k, _, cu_seqlens_k, max_seqlen_k = unpad_input(key_layer, attention_mask_kv)
v, _, _, _ = unpad_input(value_layer, attention_mask_kv)
-causal = self.reset_is_causal(
-query_layer.shape[1], key_layer.shape[1], self.attn_mask_type == AttnMaskType.causal
-)
+is_causal = self.attn_mask_type == AttnMaskType.causal and query_layer.shape[1] == key_layer.shape[1]
context_layer = flash_attn_unpadded_func(
q,
k,
Expand All @@ -1006,7 +994,7 @@ def flash_attention_cuda(self, query_layer, key_layer, value_layer, attention_ma
max_seqlen_q,
max_seqlen_k,
dropout_p=self.attention_dropout_p if self.training else 0.0,
-causal=causal,
+causal=is_causal,
)

# [b, sq, np, hn]
@@ -1031,13 +1019,13 @@ def flash_attention_triton(self, query_layer, key_layer, value_layer, attention_
attention_mask_q = attention_mask.unsqueeze(1).unsqueeze(3)
attention_mask_kv = attention_mask.unsqueeze(1).unsqueeze(2)

-attention_bias = attention_bias.masked_fill(~attention_mask_q, torch.finfo(query_layer.dtype).min)
-attention_bias = attention_bias.masked_fill(~attention_mask_kv, torch.finfo(query_layer.dtype).min)
+if attention_bias.shape[2] == attention_mask_q.shape[2]:
+attention_bias = attention_bias.masked_fill(~attention_mask_q, torch.finfo(query_layer.dtype).min)
+if attention_bias.shape[3] == attention_mask_kv.shape[3]:
+attention_bias = attention_bias.masked_fill(~attention_mask_kv, torch.finfo(query_layer.dtype).min)

-causal = self.reset_is_causal(
-query_layer.shape[1], key_layer.shape[1], self.attn_mask_type == AttnMaskType.causal
-)
-context_layer = flash_attn_func(query_layer, key_layer, value_layer, attention_bias, causal)
+is_causal = self.attn_mask_type == AttnMaskType.causal and query_layer.shape[1] == key_layer.shape[1]
+context_layer = flash_attn_func(query_layer, key_layer, value_layer, attention_bias, is_causal,)

# [b, sq, np, hn] -> [b, np, sq, hn]
context_layer = context_layer.permute(0, 2, 1, 3)
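The substantive change in the flash-attention paths above: the removed `reset_is_causal` helper (which raised `NotImplementedError` for mismatched query/key lengths) is replaced by an inline check, and the Triton path now applies each `masked_fill` only when the bias dimension actually matches the corresponding mask, which is what accommodates cross attention where query and key lengths differ. A hedged sketch of that logic with hypothetical helper names, not the NeMo code itself:

```python
# Sketch of the new behaviour under assumed shapes (not the NeMo code).
import torch

def infer_is_causal(is_causal_mask_type: bool, query_len: int, key_len: int) -> bool:
    # request a causal kernel only when the mask type is causal and the
    # query/key sequence lengths match (e.g. not for a single-token decode step)
    return is_causal_mask_type and query_len == key_len

def mask_bias(attention_bias, attention_mask_q, attention_mask_kv):
    # attention_bias: [b, np, sq, sk]; attention_mask_q: [b, 1, sq, 1];
    # attention_mask_kv: [b, 1, 1, sk]; True means "keep", False means "masked out"
    min_val = torch.finfo(attention_bias.dtype).min
    if attention_bias.shape[2] == attention_mask_q.shape[2]:
        attention_bias = attention_bias.masked_fill(~attention_mask_q, min_val)
    if attention_bias.shape[3] == attention_mask_kv.shape[3]:
        attention_bias = attention_bias.masked_fill(~attention_mask_kv, min_val)
    return attention_bias

assert infer_is_causal(True, 128, 128) is True
assert infer_is_causal(True, 1, 128) is False   # e.g. single-token query at decode time
```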
@@ -53,7 +53,7 @@ def forward_torch_softmax(self, input, mask):
probs = torch.nn.Softmax(dim=-1)(mask_output)
if mask is not None:
all_k_masked = mask.all(axis=-1)
-zero_attention_mask = (1.0 - all_k_masked.float())[:, :, :, None]
+zero_attention_mask = (1.0 - all_k_masked.type(probs.type()))[:, :, :, None]
probs = probs * zero_attention_mask

if self.input_in_float16 and self.softmax_in_fp32:
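The one-line change in this test helper keeps the zeroing mask in the same dtype as the probabilities instead of hard-coding float32, so fp16 inputs are not silently promoted. A small illustration of the dtype effect (assumed shapes, not the test itself):

```python
# Multiplying fp16 probabilities by a float32 mask promotes the result to
# float32; casting the mask to probs' own dtype keeps everything in fp16.
import torch

probs = torch.rand(2, 1, 4, 4, dtype=torch.float16)   # fp16 attention probs
mask = torch.zeros(2, 1, 4, 4, dtype=torch.bool)      # nothing masked here
all_k_masked = mask.all(axis=-1)                       # [2, 1, 4]

old = probs * (1.0 - all_k_masked.float())[:, :, :, None]             # float32 result
new = probs * (1.0 - all_k_masked.type(probs.type()))[:, :, :, None]  # float16 result

print(old.dtype, new.dtype)  # torch.float32 torch.float16
```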
@@ -70,7 +70,7 @@ def get_language_model(
vocab_size,
num_attention_heads,
encoder_attn_mask_type,
-apply_query_key_layer_scaling=True,
+apply_query_key_layer_scaling=False,
kv_channels=None,
init_method=None,
scaled_init_method=None,
@@ -44,7 +44,7 @@ def get_decoder_model(
ffn_hidden_size,
num_layers,
num_attention_heads,
-apply_query_key_layer_scaling=True,
+apply_query_key_layer_scaling=False,
kv_channels=None,
init_method=None,
scaled_init_method=None,
@@ -45,7 +45,7 @@ def get_encoder_model(
ffn_hidden_size,
num_layers,
num_attention_heads,
-apply_query_key_layer_scaling=True,
+apply_query_key_layer_scaling=False,
kv_channels=None,
init_method=None,
scaled_init_method=None,
8 changes: 4 additions & 4 deletions nemo/collections/nlp/modules/common/megatron/transformer.py
@@ -141,7 +141,7 @@ def __init__(
self_attn_mask_type=AttnMaskType.padding,
fp32_residual_connection=False,
precision=16,
-apply_query_key_layer_scaling=True,
+apply_query_key_layer_scaling=False,
kv_channels=None,
layernorm_epsilon=1e-5,
hidden_dropout=0.1,
@@ -659,7 +659,7 @@ def __init__(
self_attn_mask_type=AttnMaskType.padding,
fp32_residual_connection=False,
precision=16,
-apply_query_key_layer_scaling=True,
+apply_query_key_layer_scaling=False,
kv_channels=None,
layernorm_epsilon=1e-5,
hidden_dropout=0.1,
@@ -804,7 +804,7 @@ def __init__(
params_dtype: torch.dtype = torch.float32,
get_rng_state_tracker: Optional[Callable] = None,
fuse_wgrad_accumulation: bool = False,
-apply_query_key_layer_scaling: bool = True,
+apply_query_key_layer_scaling: bool = False,
attention_softmax_in_fp32: bool = False,
seq_length: Optional[int] = None,
micro_batch_size: Optional[int] = None,
@@ -895,7 +895,7 @@ def __init__(
hidden_size,
ffn_hidden_size,
num_attention_heads,
-apply_query_key_layer_scaling=True,
+apply_query_key_layer_scaling=False,
kv_channels=None,
layer_type=LayerType.encoder, # it can be a list of types or single type
self_attn_mask_type=AttnMaskType.padding,
(Diff for the remaining changed file not shown.)
