Merge pull request #33 from tomaarsen/qwen/update_registered_causal_mask
Update QWen due to changes in the modeling files of QWen-7b
tomaarsen authored Nov 23, 2023
2 parents 34d071c + 1ef7892 · commit a9ef1e6
Showing 1 changed file with 3 additions and 5 deletions.

attention_sinks/models/qwen/pos_shift.py (3 additions, 5 deletions)
@@ -36,11 +36,6 @@ def qwen_pos_shift_attention_forward(
     output_attentions: Optional[bool] = False,
     use_cache: Optional[bool] = False,
 ):
-    if registered_causal_mask is None:
-        raise ValueError(
-            "Attention Sinks does not support Flash Attention in QWen models, please use `use_flash_attn=False` in `AutoModelForCausalLM.from_pretrained`."
-        )
-
     mixed_x_layer = self.c_attn(hidden_states)
 
     query, key, value = mixed_x_layer.split(self.split_size, dim=2)
@@ -96,6 +91,9 @@ def qwen_pos_shift_attention_forward(
         logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :]
         query = query * logn_tensor.expand_as(query)
 
+    registered_causal_mask = torch.tril(
+        torch.ones((key.size(1), key.size(1)), dtype=torch.bool, device=key.device)
+    ).view(1, 1, key.size(1), key.size(1))
     query = query.permute(0, 2, 1, 3)
     key = key.permute(0, 2, 1, 3)
     value = value.permute(0, 2, 1, 3)
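The added lines build the causal mask inline from the current key length instead of relying on the registered_causal_mask argument, which the updated QWen-7b modeling files apparently no longer provide (per the commit message). A minimal standalone sketch of the mask those three lines produce, using a hypothetical key length of 4 in place of key.size(1):

import torch

seq_len = 4  # hypothetical stand-in for key.size(1) in the diff above
# torch.tril keeps the lower triangle, so query position i can only attend
# to key positions <= i; the view adds batch and head broadcast dimensions.
registered_causal_mask = torch.tril(
    torch.ones((seq_len, seq_len), dtype=torch.bool)
).view(1, 1, seq_len, seq_len)
print(registered_causal_mask[0, 0])
# tensor([[ True, False, False, False],
#         [ True,  True, False, False],
#         [ True,  True,  True, False],
#         [ True,  True,  True,  True]])

Because the mask is rebuilt on every forward call from key.size(1), its shape always tracks the current (cache-extended) key length, so no fixed-size registered buffer is required.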
