
Commit

bug fix
VarunGumma committed Jul 6, 2024
1 parent d1ed050 commit 68c2516
Showing 3 changed files with 18 additions and 11 deletions.
17 changes: 15 additions & 2 deletions fairseq/models/transformer/transformer_decoder.py
@@ -28,6 +28,7 @@
 from fairseq.modules.quant_noise import quant_noise as apply_quant_noise_
 
 from fairseq.modules.rms_norm import RMSNorm
+from einops import rearrange
 
 device = "cuda" if torch.cuda.is_available() else "cpu"

@@ -372,7 +373,11 @@ def extract_features_scriptable(
             else None
         )
 
-        if self.alibi is not None and self_attn_mask is None and incremental_state is not None:
+        if (
+            self.alibi is not None
+            and self_attn_mask is None
+            and incremental_state is not None
+        ):
             self_attn_mask = self._bias_attn_mask(x, incremental_state)
 
         # B x T x C -> T x B x C
@@ -442,7 +447,15 @@ def _bias_attn_mask(self, x, incremental_state):
 
         src_len = saved_state["prev_key"].shape[2]
 
-        return self.alibi[:, src_len, : src_len + 1].unsqueeze(1).to(x.device)
+        attn_mask = (
+            self.alibi[:, src_len, : src_len + 1]
+            .unsqueeze(1)
+            .unsqueeze(0)
+            .expand(x.size(0), -1, -1, -1)
+            .to(x.device)
+        )
+        attn_mask = rearrange(attn_mask, "b h t c -> (b h) t c")
+        return attn_mask
 
     def buffered_future_mask(self, tensor):
         B, T, _ = tensor.size()
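
A minimal sketch of the shape arithmetic behind the new _bias_attn_mask (toy sizes, assuming a precomputed ALiBi bias of shape (num_heads, max_len, max_len); not code from the commit): the bias slice for the current decoding step is expanded over the batch and flattened across heads, so it ends up in the (bsz * num_heads, tgt_len, src_len) layout that the attention weights use, rather than the head-only shape the old return value had.

import torch
from einops import rearrange

bsz, num_heads, src_len = 2, 8, 5            # illustrative values only
alibi = torch.randn(num_heads, 16, 16)       # stand-in for the precomputed ALiBi bias

attn_mask = (
    alibi[:, src_len, : src_len + 1]         # (num_heads, src_len + 1)
    .unsqueeze(1)                            # (num_heads, 1, src_len + 1)
    .unsqueeze(0)                            # (1, num_heads, 1, src_len + 1)
    .expand(bsz, -1, -1, -1)                 # (bsz, num_heads, 1, src_len + 1)
)
attn_mask = rearrange(attn_mask, "b h t c -> (b h) t c")
print(attn_mask.shape)                       # torch.Size([16, 1, 6])

This flattened layout is what the simplified mask addition in the next file relies on.
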
9 changes: 1 addition & 8 deletions fairseq/modules/native_multihead_attention.py
@@ -327,14 +327,7 @@ def forward(
         assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
 
         if attn_mask is not None:
-            if saved_state is not None:
-                # HACK: this happens only with ALiBi, when the attn_mask is the attn_bias
-                # in no other inference, this branch should be taken
-                attn_weights = rearrange(attn_weights, "(b h) t s -> b h t s", h=self.num_heads)
-                attn_weights += attn_mask
-                attn_weights = rearrange(attn_weights, "b h t s -> (b h) t s", h=self.num_heads)
-            else:
-                attn_weights += attn_mask
+            attn_weights += attn_mask
 
         if key_padding_mask is not None:
             # don't attend to padding symbols
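
A short sketch (not part of the commit, toy sizes only) of why the special-cased branch above could be dropped: once _bias_attn_mask hands over the ALiBi bias already flattened to (bsz * num_heads, tgt_len, src_len), it has the same layout as attn_weights, so one plain addition covers it, and the usual (tgt_len, src_len) future mask still broadcasts over the leading dimension.

import torch

bsz, num_heads, tgt_len, src_len = 2, 8, 6, 6
attn_weights = torch.randn(bsz * num_heads, tgt_len, src_len)

# ALiBi bias in the layout produced by the updated _bias_attn_mask
alibi_mask = torch.randn(bsz * num_heads, tgt_len, src_len)
# ordinary causal mask used outside incremental decoding
future_mask = torch.triu(torch.full((tgt_len, src_len), float("-inf")), diagonal=1)

attn_weights = attn_weights + alibi_mask    # shapes match exactly: elementwise add
attn_weights = attn_weights + future_mask   # (tgt_len, src_len) broadcasts over bsz * num_heads
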
3 changes: 2 additions & 1 deletion setup.py
@@ -200,7 +200,7 @@ def do_setup(package_data):
             "setuptools>=18.0",
         ],
         install_requires=[
-            "cffi",
+            "cffi",
             "cython",
             "omegaconf",
             'dataclasses; python_version<"3.7"',
@@ -214,6 +214,7 @@ def do_setup(package_data):
             "torchaudio>=0.8.0",
             "scikit-learn",
             "packaging",
+            "rotary-embedding-torch>=0.6.4"
         ],
         dependency_links=dependency_links,
         packages=find_packages(
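
The newly required rotary-embedding-torch package is not exercised by the diff above. As a hedged illustration only, this is the typical usage of its documented RotaryEmbedding interface; the shapes and parameter choices here are assumptions, not the commit's code.

import torch
from rotary_embedding_torch import RotaryEmbedding

rotary = RotaryEmbedding(dim=32)             # number of rotary feature dims (assumed choice)
q = torch.randn(1, 8, 1024, 64)              # (batch, heads, seq_len, head_dim)
k = torch.randn(1, 8, 1024, 64)

q = rotary.rotate_queries_or_keys(q)         # inject positions by rotating queries
k = rotary.rotate_queries_or_keys(k)         # and keys before the attention dot product
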
