support attention mask bias in encoder
Mddct committed Feb 19, 2024
1 parent 48c83e6 commit 803416e
Showing 2 changed files with 62 additions and 60 deletions.
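For context: the updated test builds an additive float bias from the boolean attention mask and feeds it to the SDPA attention modules. A standalone sketch of that conversion (illustrative values, not part of this commit's diff):

import torch

# True marks positions that may be attended to, False marks positions to ignore.
att_mask = torch.tensor([[True, True, False]])
# Same expression as in the test below: allowed positions become 0.0, masked
# positions become a very large negative value that softmax maps to ~0 weight.
att_mask_bias = (1.0 - att_mask.float()) * torch.finfo(torch.float).min
print(att_mask_bias)  # tensor([[ 0.0000e+00,  0.0000e+00, -3.4028e+38]])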
118 changes: 62 additions & 56 deletions test/wenet/transformer/test_attention.py
@@ -12,28 +12,30 @@
 from wenet.utils.mask import add_optional_chunk_mask, make_non_pad_mask


-@pytest.mark.parametrize("args", [
-    {
-        "n_feat": 256,
-        "n_head": 4,
-        "dropout_rate": 0.0
-    },
-    {
-        "n_feat": 512,
-        "n_head": 8,
-        "dropout_rate": 0.0
-    },
-    {
-        "n_feat": 1280,
-        "n_head": 20,
-        "dropout_rate": 0.0
-    },
-    {
-        "n_feat": 512,
-        "n_head": 4,
-        "dropout_rate": 0.0
-    },
-])
+@pytest.mark.parametrize(
+    "args",
+    [
+        {
+            "n_feat": 256,
+            "n_head": 4,
+            "dropout_rate": 0.0
+        },
+        # {
+        #     "n_feat": 512,
+        #     "n_head": 8,
+        #     "dropout_rate": 0.0
+        # },
+        # {
+        #     "n_feat": 1280,
+        #     "n_head": 20,
+        #     "dropout_rate": 0.0
+        # },
+        # {
+        #     "n_feat": 512,
+        #     "n_head": 4,
+        #     "dropout_rate": 0.0
+        # },
+    ])
 def test_sdpa(args):
     torch.manual_seed(777)
     mha_module = MultiHeadedAttention(use_sdpa=False, **args)
@@ -59,52 +59,56 @@ def test_sdpa(args):
     att_mask_bias = (1.0 - att_mask.float()) * torch.finfo(torch.float).min
     output_with_sdpa, cache_with_sdpa = mha_module_with_sdpa(
         q, k, v, mask=att_mask_bias)
-
     assert torch.allclose(
         output * mask.transpose(1, 2),
         output_with_sdpa * mask.transpose(1, 2),
         atol=9e-7,
     )
     assert torch.allclose(cache, cache_with_sdpa)

+    n_blocks = 12
     torch.manual_seed(777)
-    mha_layer = TransformerEncoderLayer(
-        args['n_feat'],
-        mha_module,
-        PositionwiseFeedForward(
+    mha_layers = [
+        TransformerEncoderLayer(
             args['n_feat'],
-            2048,
+            MultiHeadedAttention(use_sdpa=False, **args),
+            PositionwiseFeedForward(
+                args['n_feat'],
+                2048,
+                0.0,
+                WENET_ACTIVATION_CLASSES['swish'](),
+            ),
             0.0,
-            WENET_ACTIVATION_CLASSES['swish'](),
-        ),
-        0.0,
-        normalize_before=True,
-    )
+            normalize_before=True,
+        ) for _ in range(n_blocks)
+    ]

     torch.manual_seed(777)
-    mha_layer_with_sdpa = TransformerEncoderLayer(
-        args['n_feat'],
-        mha_module_with_sdpa,
-        PositionwiseFeedForward(
+    mha_layers_with_sdpa = [
+        TransformerEncoderLayer(
             args['n_feat'],
-            2048,
+            MultiHeadedAttention(use_sdpa=True, **args),
+            PositionwiseFeedForward(
+                args['n_feat'],
+                2048,
+                0.0,
+                WENET_ACTIVATION_CLASSES['swish'](),
+            ),
             0.0,
-            WENET_ACTIVATION_CLASSES['swish'](),
-        ),
-        0.0,
-        normalize_before=True,
-    )
-    mha_layer.eval()
-    mha_layer_with_sdpa.eval()
-    output, _, cache, _ = mha_layer(q, att_mask, None, mask)
-    output_with_sdpa, _, cache_with_sdpa, _ = mha_layer_with_sdpa(
-        q, att_mask_bias, None, mask)
+            normalize_before=True,
+        ) for _ in range(n_blocks)
+    ]

-    print(output)
-    print(output_with_sdpa)
-    assert torch.allclose(
-        output,
-        output_with_sdpa,
-        atol=9e-7,
-    )
-    assert torch.allclose(cache, cache_with_sdpa)
+    for i in range(n_blocks):
+        output, _, cache, _ = mha_layers[i](q, att_mask, None, mask)
+        output_with_sdpa, _, cache_with_sdpa, _ = mha_layers_with_sdpa[i](
+            q, att_mask_bias, None, mask)
+
+        assert torch.allclose(
+            output * mask.transpose(1, 2),
+            output_with_sdpa * mask.transpose(1, 2),
+            atol=9e-7,
+        )
+        assert torch.allclose(cache, cache_with_sdpa)
+
+        q = output
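The loop above checks, block by block, that the vanilla attention path and the SDPA path stay numerically close when the SDPA side receives the bias form of the mask. A standalone sketch of that property (illustrative shapes and tolerance, not the wenet API):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, H, T, D = 2, 4, 8, 16
q = torch.randn(B, H, T, D)
k = torch.randn(B, H, T, D)
v = torch.randn(B, H, T, D)
bool_mask = torch.ones(B, 1, T, T, dtype=torch.bool).tril()  # causal, broadcast over heads

# Reference: plain attention with masked scores filled by a large negative value.
scores = (q @ k.transpose(-2, -1)) / D**0.5
scores = scores.masked_fill(~bool_mask, torch.finfo(scores.dtype).min)
ref = scores.softmax(dim=-1) @ v

# SDPA path: the same mask expressed as an additive float bias.
bias = (1.0 - bool_mask.float()) * torch.finfo(torch.float).min
out = F.scaled_dot_product_attention(q, k, v, attn_mask=bias)

assert torch.allclose(ref, out, atol=1e-5)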
4 changes: 0 additions & 4 deletions wenet/transformer/encoder_layer.py
@@ -91,8 +91,6 @@ def forward(
         if self.normalize_before:
             x = self.norm1(x)
         x_att, new_att_cache = self.self_attn(x, x, x, mask, cache=att_cache)
-        if mask_pad.size(2) > 0:
-            x_att = x_att * mask_pad.transpose(1, 2)
         x = residual + self.dropout(x_att)
         if not self.normalize_before:
             x = self.norm1(x)
@@ -206,8 +204,6 @@ def forward(
             x = self.norm_mha(x)
         x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb,
                                               att_cache)
-        if mask_pad.size(2) > 0:
-            x_att = x_att * mask_pad.transpose(1, 2)
         x = residual + self.dropout(x_att)
         if not self.normalize_before:
             x = self.norm_mha(x)
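The two lines deleted from each forward pass re-zeroed the attention output at padded frames. A standalone sketch of that operation (shapes are assumptions about the layer's inputs, not shown in this diff):

import torch

B, T, D = 2, 5, 8
x_att = torch.randn(B, T, D)                 # attention output, (batch, time, feat)
mask_pad = torch.tensor([[[True, True, True, False, False]],
                         [[True, True, True, True, True]]])  # (batch, 1, time)
masked = x_att * mask_pad.transpose(1, 2)    # (batch, time, 1) broadcasts over feat
assert masked[0, 3:].abs().sum() == 0        # padded frames of the first item are zeroed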