Commit 576ff68

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] authored and hsiehjackson committed Jun 21, 2023
1 parent 8303f08 commit 576ff68
Showing 2 changed files with 11 additions and 17 deletions.
12 changes: 3 additions & 9 deletions nemo/collections/nlp/modules/common/megatron/attention.py
@@ -1018,20 +1018,14 @@ def flash_attention_triton(self, query_layer, key_layer, value_layer, attention_
assert len(attention_mask.shape) == 2
attention_mask_q = attention_mask.unsqueeze(1).unsqueeze(3)
attention_mask_kv = attention_mask.unsqueeze(1).unsqueeze(2)

if attention_bias.shape[2] == attention_mask_q.shape[2]:
attention_bias = attention_bias.masked_fill(~attention_mask_q, torch.finfo(query_layer.dtype).min)
if attention_bias.shape[3] == attention_mask_kv.shape[3]:
attention_bias = attention_bias.masked_fill(~attention_mask_kv, torch.finfo(query_layer.dtype).min)

is_causal = self.attn_mask_type == AttnMaskType.causal and query_layer.shape[1] == key_layer.shape[1]
context_layer = flash_attn_func(
query_layer,
key_layer,
value_layer,
attention_bias,
is_causal,
)
context_layer = flash_attn_func(query_layer, key_layer, value_layer, attention_bias, is_causal,)

# [b, sq, np, hn] -> [b, np, sq, hn]
context_layer = context_layer.permute(0, 2, 1, 3)
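As a reference for the hunk above, here is a minimal, self-contained sketch of the padding-mask-to-bias logic that precedes the reformatted flash_attn_func call. The shapes, dtype, and toy tensors are assumptions for illustration only, and the kernel call itself is left as a comment because it requires the Triton flash-attention extension.

import torch

# Toy shapes (assumed): batch, heads, query length, key length
b, np_heads, sq, sk = 2, 4, 8, 8
dtype = torch.float16

# 2D padding mask, True = valid token, False = padding (assumed example)
attention_mask = torch.ones(b, sk, dtype=torch.bool)
attention_mask[1, 5:] = False

# Additive attention bias with the broadcastable shape used in the hunk
attention_bias = torch.zeros(b, np_heads, sq, sk, dtype=dtype)

# Broadcast the 2D mask over the query and key axes, as in the hunk
attention_mask_q = attention_mask.unsqueeze(1).unsqueeze(3)   # [b, 1, sq, 1]
attention_mask_kv = attention_mask.unsqueeze(1).unsqueeze(2)  # [b, 1, 1, sk]

# Fill masked positions with the dtype minimum so they vanish after softmax
if attention_bias.shape[2] == attention_mask_q.shape[2]:
    attention_bias = attention_bias.masked_fill(~attention_mask_q, torch.finfo(dtype).min)
if attention_bias.shape[3] == attention_mask_kv.shape[3]:
    attention_bias = attention_bias.masked_fill(~attention_mask_kv, torch.finfo(dtype).min)

# The masked bias is then passed to the Triton kernel in one call, e.g.:
# context_layer = flash_attn_func(query_layer, key_layer, value_layer, attention_bias, is_causal)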
16 changes: 8 additions & 8 deletions tests/collections/nlp/test_flash_attention.py
@@ -154,7 +154,7 @@ def test_flash_self_attention(self, cfg):
hidden_size=h,
attn_mask_type=AttnMaskType.causal,
attention_dropout=0.0,
apply_query_key_layer_scaling=False
apply_query_key_layer_scaling=False,
)

attention_fa = CoreAttention(
@@ -171,7 +171,7 @@ def test_flash_self_attention(self, cfg):
torch.testing.assert_close(out, out_fa)
out_fa = attention_fa(q, k, v, attention_mask_2d)
torch.testing.assert_close(out, out_fa)

@pytest.mark.skipif(not HAVE_FA, reason="flash-attention is not installed")
@pytest.mark.unit
def test_flash_cross_attention(self, cfg):
@@ -187,7 +187,7 @@ def test_flash_cross_attention(self, cfg):
attention_mask_2d_q = torch.arange(sq, device=device).unsqueeze(0) < torch.randint(
1, sq, (bz,), device=device
).unsqueeze(1)

attention_mask_2d_k = torch.arange(sk, device=device).unsqueeze(0) < torch.randint(
1, sk, (bz,), device=device
).unsqueeze(1)
@@ -202,7 +202,7 @@ def test_flash_cross_attention(self, cfg):
hidden_size=h,
attn_mask_type=AttnMaskType.padding,
attention_dropout=0.0,
apply_query_key_layer_scaling=False
apply_query_key_layer_scaling=False,
)

attention_fa = CoreAttention(
@@ -256,7 +256,7 @@ def test_flash_self_attention_triton(self, cfg):
hidden_size=h,
attn_mask_type=AttnMaskType.padding,
attention_dropout=0.0,
apply_query_key_layer_scaling=False
apply_query_key_layer_scaling=False,
)

attention_fa = CoreAttention(
@@ -281,7 +281,7 @@ def test_flash_self_attention_triton(self, cfg):
hidden_size=h,
attn_mask_type=AttnMaskType.causal,
attention_dropout=0.0,
apply_query_key_layer_scaling=False
apply_query_key_layer_scaling=False,
)

attention_fa = CoreAttention(
@@ -319,7 +319,7 @@ def test_flash_cross_attention_triton(self, cfg):
attention_mask_2d_q = torch.arange(sq, device=device).unsqueeze(0) < torch.randint(
1, sq, (bz,), device=device
).unsqueeze(1)

attention_mask_2d_k = torch.arange(sk, device=device).unsqueeze(0) < torch.randint(
1, sk, (bz,), device=device
).unsqueeze(1)
@@ -336,7 +336,7 @@ def test_flash_cross_attention_triton(self, cfg):
hidden_size=h,
attn_mask_type=AttnMaskType.padding,
attention_dropout=0.0,
apply_query_key_layer_scaling=False
apply_query_key_layer_scaling=False,
)

attention_fa = CoreAttention(
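For context, a minimal sketch of the boolean padding masks these tests build for queries and keys (the same pattern appears in both cross-attention tests). The batch size, sequence lengths, and CPU device below are assumed toy values; the tests themselves use their configured sizes and device.

import torch

# Toy values (assumed); the tests use cfg-provided sizes and device
device = "cpu"
bz, sq, sk = 2, 6, 6

# Each sample keeps a random prefix of positions: position < random_length
attention_mask_2d_q = torch.arange(sq, device=device).unsqueeze(0) < torch.randint(
    1, sq, (bz,), device=device
).unsqueeze(1)

attention_mask_2d_k = torch.arange(sk, device=device).unsqueeze(0) < torch.randint(
    1, sk, (bz,), device=device
).unsqueeze(1)

print(attention_mask_2d_q.shape, attention_mask_2d_k.shape)  # torch.Size([2, 6]) torch.Size([2, 6])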
