Fix SDPA attention precision issue in Qwen2.5-VL #37363
Conversation
|
Hi 👋, thank you for opening this pull request! The pull request is converted to draft by default. The CI will be paused while the PR is in draft mode. When it is ready for review, please click the |
|
Hey did I miss something? |
|
Hi 👋, thank you for opening this pull request! The pull request is converted to draft by default. The CI will be paused while the PR is in draft mode. When it is ready for review, please click the |
|
@ArthurZucker Hey, it's nothing. I'm currently exploring the implementation of |
|
Unfortunately, |
|
@JJJYmmm we can pad the inputs to match what's supported no? |
ArthurZucker
left a comment
Looks great! Can you just run the integration tests to make sure the results are the same? 🤗
|
@ArthurZucker Given the potential for a large number of windows, parallelization may be more efficient than a for loop. I'll experiment with padding and perform a thorough test to ensure it maintains consistency with the original behavior. I'll update you once it is prepared! 👋 |
|
Hi, I implemented the padding-based approach and compared it with the existing for-loop implementation. Across various sequence lengths and window sizes, the padding method was not significantly faster than the for loop, likely because it has to track extra length information for each sample. So I think the current for-loop implementation is efficient enough. 🤗 Additionally, I verified the consistency of these implementations (scripts and results below) and successfully passed the fast tests.

Scripts and results:

import time
import math
import torch
from torch import nn
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import apply_rotary_pos_emb_vision
from typing import Optional, Tuple
class Qwen2_5_VLVisionAttention(nn.Module):
def __init__(self, dim: int, num_heads: int = 16) -> None:
super().__init__()
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.qkv = nn.Linear(dim, dim * 3, bias=True)
self.proj = nn.Linear(dim, dim)
def forward_origin(
self,
hidden_states: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb: Optional[torch.Tensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> torch.Tensor:
seq_length = hidden_states.shape[0]
q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
if position_embeddings is None:
emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
cos = emb.cos()
sin = emb.sin()
else:
cos, sin = position_embeddings
q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)
attention_mask = torch.full(
[1, seq_length, seq_length], torch.finfo(q.dtype).min, device=q.device, dtype=q.dtype
)
for i in range(1, len(cu_seqlens)):
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
q = q.transpose(0, 1)
k = k.transpose(0, 1)
v = v.transpose(0, 1)
attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim)
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
attn_output = torch.matmul(attn_weights, v)
attn_output = attn_output.transpose(0, 1)
attn_output = attn_output.reshape(seq_length, -1)
attn_output = self.proj(attn_output)
return attn_output
def forward(
self,
hidden_states: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb: Optional[torch.Tensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> torch.Tensor:
seq_length = hidden_states.shape[0]
q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
if position_embeddings is None:
emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
cos = emb.cos()
sin = emb.sin()
else:
cos, sin = position_embeddings
q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)
q = q.transpose(0, 1)
k = k.transpose(0, 1)
v = v.transpose(0, 1)
attn_output = []
for i in range(len(cu_seqlens) - 1):
start = cu_seqlens[i]
end = cu_seqlens[i + 1]
q_window = q[:, start:end, :]
k_window = k[:, start:end, :]
v_window = v[:, start:end, :]
attn_weights = torch.matmul(q_window, k_window.transpose(1, 2)) / math.sqrt(self.head_dim)
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
attn_output_window = torch.matmul(attn_weights, v_window)
attn_output.append(attn_output_window)
attn_output = torch.cat(attn_output, dim=1).transpose(0, 1)
attn_output = attn_output.reshape(seq_length, -1)
attn_output = self.proj(attn_output)
return attn_output
def forward_padding(
self,
hidden_states: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb: Optional[torch.Tensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> torch.Tensor:
seq_length = hidden_states.shape[0]
q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
if position_embeddings is None:
emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
cos = emb.cos()
sin = emb.sin()
else:
cos, sin = position_embeddings
q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)
q = q.transpose(0, 1)
k = k.transpose(0, 1)
v = v.transpose(0, 1)
batch_size = cu_seqlens.numel() - 1
window_lengths = cu_seqlens[1:] - cu_seqlens[:-1]
max_window_length = window_lengths.max().item()
q_padded = torch.zeros(self.num_heads, batch_size, max_window_length, self.head_dim, device=q.device, dtype=q.dtype)
k_padded = torch.zeros(self.num_heads, batch_size, max_window_length, self.head_dim, device=k.device, dtype=k.dtype)
v_padded = torch.zeros(self.num_heads, batch_size, max_window_length, self.head_dim, device=v.device, dtype=v.dtype)
for i, (start, end) in enumerate(zip(cu_seqlens[:-1], cu_seqlens[1:])):
q_padded[:, i, :end-start, :] = q[:, start:end, :]
k_padded[:, i, :end-start, :] = k[:, start:end, :]
v_padded[:, i, :end-start, :] = v[:, start:end, :]
attn_weights = torch.matmul(q_padded, k_padded.transpose(-1, -2)) / math.sqrt(self.head_dim)
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
attn_output_padded = torch.matmul(attn_weights, v_padded)
attn_output = []
for i, length in enumerate(window_lengths):
attn_output.append(attn_output_padded[:, i, :length, :])
attn_output = torch.cat(attn_output, dim=1)
attn_output = attn_output.transpose(0, 1)
attn_output = attn_output.reshape(seq_length, -1)
attn_output = self.proj(attn_output)
return attn_output
def check_consistency(seq_len, block_size):
# FAKE INPUTS
hidden_states = torch.randn(seq_len, 1280).cuda()
cu_seqlens = torch.tensor([i for i in range(0, seq_len + 1, block_size)], dtype=torch.int32).cuda()
position_embeddings = (torch.randn(seq_len, 80).cuda(), torch.randn(seq_len, 80).cuda())
# Attention
eager_attention = Qwen2_5_VLVisionAttention(1280, 16).cuda()
# Number of repetitions
repetitions = 20
# Measure time for modified forward
modified_times = []
for _ in range(repetitions):
start_time = time.time()
eager_attention(hidden_states, cu_seqlens, position_embeddings=position_embeddings)
modified_times.append(time.time() - start_time)
# Measure time for origin forward
origin_times = []
for _ in range(repetitions):
start_time = time.time()
eager_attention.forward_origin(hidden_states, cu_seqlens, position_embeddings=position_embeddings)
origin_times.append(time.time() - start_time)
# Measure time for padding forward
padding_times = []
for _ in range(repetitions):
start_time = time.time()
eager_attention.forward_padding(hidden_states, cu_seqlens, position_embeddings=position_embeddings)
padding_times.append(time.time() - start_time)
# Calculate average times
avg_modified_time = sum(modified_times) / repetitions
avg_origin_time = sum(origin_times) / repetitions
avg_padding_time = sum(padding_times) / repetitions
# Print results
print(f"Average time for modified forward: {avg_modified_time:.6f} seconds")
print(f"Average time for origin forward: {avg_origin_time:.6f} seconds")
print(f"Average time for padding forward: {avg_padding_time:.6f} seconds")
# Consistency checks
output_modified = eager_attention(hidden_states, cu_seqlens, position_embeddings=position_embeddings)
output_origin = eager_attention.forward_origin(hidden_states, cu_seqlens, position_embeddings=position_embeddings)
output_padding = eager_attention.forward_padding(hidden_states, cu_seqlens, position_embeddings=position_embeddings)
assert torch.allclose(output_modified, output_origin, atol=1e-5), "Output is not consistent between modified and origin!"
assert torch.allclose(output_padding, output_origin, atol=1e-5), "Output is not consistent between padding and origin!"
print("Output is consistent!")
if __name__ == "__main__":
with torch.no_grad():
seq_len = [12096, 4096]
block_size = [[32, 64, 64 * 3], [32,64]]
for i in range(len(seq_len)):
for j in range(len(block_size[i])):
print(f"===== seq_len: {seq_len[i]}, block_size: {block_size[i][j]} =====")
check_consistency(seq_len[i], block_size[i][j])

===== seq_len: 12096, block_size: 32 =====
Average time for modified forward: 0.073282 seconds
Average time for origin forward: 0.125473 seconds
Average time for padding forward: 0.093071 seconds
Output is consistent!
===== seq_len: 12096, block_size: 64 =====
Average time for modified forward: 0.038572 seconds
Average time for origin forward: 0.113410 seconds
Average time for padding forward: 0.120773 seconds
Output is consistent!
===== seq_len: 12096, block_size: 192 =====
Average time for modified forward: 0.025483 seconds
Average time for origin forward: 0.125795 seconds
Average time for padding forward: 0.040718 seconds
Output is consistent!
===== seq_len: 4096, block_size: 32 =====
Average time for modified forward: 0.038642 seconds
Average time for origin forward: 0.046402 seconds
Average time for padding forward: 0.062893 seconds
Output is consistent!
===== seq_len: 4096, block_size: 64 =====
Average time for modified forward: 0.021746 seconds
Average time for origin forward: 0.028758 seconds
Average time for padding forward: 0.029745 seconds
Output is consistent! |
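As a hedged aside on the timings above: wrapping asynchronous CUDA kernels with time.time() alone can under-report their cost, so a synchronized helper such as the sketch below (not part of the original script; the helper name is illustrative) usually gives more reliable averages.

def timed(fn, *args, repetitions=20, **kwargs):
    # wait for pending CUDA work before and after the loop so the wall clock reflects kernel time
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(repetitions):
        fn(*args, **kwargs)
    torch.cuda.synchronize()
    return (time.time() - start) / repetitions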
Force-pushed from 50a81dc to 397511e.
|
CC @ArthurZucker all tests are passing! 😊 |
ArthurZucker
left a comment
thanks 🫡
I think kernels will be a great fit long term here!
Just to check, did you run this with compile (e.g. generate with a static cache)?
|
@ArthurZucker Agree that kernels would be a great fit here. Tested generation with static cache, all good here! |
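For reference, a minimal sketch of running generation with a static cache and torch.compile on Qwen2.5-VL; this is not the exact script used for the check above, and the checkpoint name, prompt, and compile options are assumptions (a VLM forward may still graph-break and need tweaks).

import torch
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration

model_id = "Qwen/Qwen2.5-VL-3B-Instruct"  # assumed checkpoint; any Qwen2.5-VL variant should work
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, attn_implementation="sdpa"
).to("cuda")
processor = AutoProcessor.from_pretrained(model_id)

# the static cache keeps shapes fixed so torch.compile can specialize the decode step
model.forward = torch.compile(model.forward, mode="reduce-overhead")

inputs = processor(text="Describe the weather today.", return_tensors="pt").to("cuda")
out = model.generate(**inputs, max_new_tokens=16, cache_implementation="static")
print(processor.batch_decode(out, skip_special_tokens=True)[0])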
|
@ArthurZucker Hey, these new commits fix the precision issue with SDPA attention reported in QwenLM/Qwen2.5-VL#1235. I made the changes directly here since they conflicted with this PR. I've also tested the models internally, and they behave as expected. 🫡 |
|
This comment contains run-slow, running the specified jobs: models: ['models/glm4v', 'models/qwen2_5_omni', 'models/qwen2_5_vl', 'models/qwen2_vl'] |
zucchini-nlp
left a comment
Looks good to me, as this fixes the equivalence issues between FA2 and sdpa/eager. I am surprised that the for-loop option is more efficient than padding hehe
Triggered slow tests, let's see if they pass. In the meantime, can I ask you to check whether this resolves #39067?
|
We also used a for loop in vLLM to make it faster and save memory before my internship, haha.

import torch
import torch.nn as nn
import math
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from typing import Optional
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(
batch, num_key_value_heads, n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
def eager_attention_forward(
module: nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
scaling: float,
dropout: float = 0.0,
**kwargs,
):
key_states = repeat_kv(key, module.num_key_value_groups)
value_states = repeat_kv(value, module.num_key_value_groups)
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
attn_weights = nn.functional.softmax(
attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(
attn_weights, p=dropout, training=module.training)
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, attn_weights
# copy from https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0,
is_causal=False, scale=None, enable_gqa=False) -> torch.Tensor:
L, S = query.size(-2), key.size(-2)
scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
if is_causal:
assert attn_mask is None
temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
attn_bias.to(query.dtype)
if attn_mask is not None:
if attn_mask.dtype == torch.bool:
attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
else:
attn_bias = attn_mask + attn_bias
if enable_gqa:
key = key.repeat_interleave(query.size(-3)//key.size(-3), -3)
value = value.repeat_interleave(query.size(-3)//value.size(-3), -3)
attn_weight = query @ key.transpose(-2, -1) * scale_factor
attn_weight += attn_bias
attn_weight = torch.softmax(attn_weight, dim=-1)
attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
return attn_weight @ value
bs, num_heads, seq_len, head_dim = 1, 16, 1024, 128
x = torch.randn(bs, num_heads, seq_len, head_dim,
dtype=torch.bfloat16, device="cuda")
kv = torch.randn(bs, num_heads, seq_len, head_dim,
dtype=torch.bfloat16, device="cuda")
class Config:
def __init__(self):
self._attn_implementation = None
class TestModule(nn.Module):
def __init__(self):
super().__init__()
self.num_key_value_groups = 1
self.scaling = head_dim**-0.5
self.training = False
self.config = Config()
self.is_causal = False
test_module = TestModule()
# test eager_attention_forward
test_module.config._attn_implementation = "eager"
input_eager = x.clone()
with torch.no_grad():
output_eager = eager_attention_forward(
test_module, input_eager, kv, kv, attention_mask=None, scaling=test_module.scaling, dropout=0.0)[0]
# test flash_attention_forward
test_module.config._attn_implementation = "flash_attention_2"
input_flash = x.clone()
cu_seqlens = torch.tensor(
[0, seq_len], dtype=torch.int32, device=input_flash.device)
with torch.no_grad():
output_flash = ALL_ATTENTION_FUNCTIONS["flash_attention_2"](test_module, input_flash, kv, kv, attention_mask=None, scaling=test_module.scaling, dropout=0.0,
cu_seq_lens_q=cu_seqlens, cu_seq_lens_k=cu_seqlens, max_length_q=seq_len, max_length_k=seq_len, is_causal=False)[0]
# test sdpa_attention_forward
test_module.config._attn_implementation = "sdpa"
input_sdpa = x.clone()
with torch.no_grad():
output_sdpa = ALL_ATTENTION_FUNCTIONS["sdpa"](test_module, input_sdpa, kv, kv, attention_mask=None, scaling=test_module.scaling, dropout=0.0,
is_causal=False)[0]
output_sdpa_manual = scaled_dot_product_attention(input_sdpa, kv, kv,
attn_mask=None, dropout_p=0.0, is_causal=False, scale=test_module.scaling, enable_gqa=False)
output_sdpa_manual = output_sdpa_manual.transpose(1, 2).contiguous()
print("sdpa and sdpa_manual outputs are equal:",
torch.allclose(output_sdpa, output_sdpa_manual, atol=1e-3))
print("eager and flash_attention_2 outputs are equal:",
torch.allclose(output_eager, output_flash, atol=1e-3))
print("eager and sdpa outputs are equal:",
torch.allclose(output_eager, output_sdpa, atol=1e-3))
print("flash_attention_2 and sdpa outputs are equal:",
torch.allclose(output_flash, output_sdpa, atol=1e-3))
print("sdpa_manual and eager outputs are equal:",
torch.allclose(output_sdpa_manual, output_eager, atol=1e-3))
print("sdpa_manual and flash_attention_2 outputs are equal:",
torch.allclose(output_sdpa_manual, output_flash, atol=1e-3))

sdpa and sdpa_manual outputs are equal: False
eager and flash_attention_2 outputs are equal: False
eager and sdpa outputs are equal: False
flash_attention_2 and sdpa outputs are equal: False
sdpa_manual and eager outputs are equal: True
sdpa_manual and flash_attention_2 outputs are equal: False

Update: I think my comparison method was too strict, because the relative mean differences are actually small:

eager and flash_attention_2 Relative mean diff: 0.005432
eager and sdpa Relative mean diff: 0.005432
eager and sdpa_manual Relative mean diff: 0.000000
flash_attention_2 and sdpa Relative mean diff: 0.000001
flash_attention_2 and sdpa_manual Relative mean diff: 0.005432
sdpa and sdpa_manual Relative mean diff: 0.005432 |
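For context, here is a hedged guess at how the relative mean diff above could be computed; the exact formula is not shown in the thread, and the snippet simply reuses the outputs from the script above.

def relative_mean_diff(a, b):
    # mean absolute difference, normalized by the mean magnitude of the reference tensor
    return ((a - b).abs().mean() / b.abs().mean()).item()

print("eager vs sdpa relative mean diff:", relative_mean_diff(output_eager, output_sdpa))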
|
Interesting, I guess Qwen-ViT is indeed not very stable. Usually we don't see such a big difference between attention implementations for other models. |
|
Yes, exactly. I think Qwen-ViT is unstable because its image feature values are so much larger than the text embedding values. We actually measured this before: going into the first LM layer, the image feature norm is about 30x the text embedding norm. We're working on a fix for this. 🧐 |
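A hypothetical sketch of that measurement: compare the norm of the projected image features with the norm of the text token embeddings entering the first LM layer. Attribute names follow the Qwen2.5-VL modeling code and may differ across transformers versions; `pixel_values`, `image_grid_thw`, and `input_ids` are assumed to come from an already prepared batch.

with torch.no_grad():
    # projected image features that get merged into the LM input sequence
    image_embeds = model.visual(pixel_values, grid_thw=image_grid_thw)
    # text token embeddings entering the first LM layer
    text_embeds = model.get_input_embeddings()(input_ids)

print("mean image feature norm:", image_embeds.norm(dim=-1).mean().item())
print("mean text embedding norm:", text_embeds.norm(dim=-1).mean().item())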
|
run-slow: glm4v, qwen2_5_omni, qwen2_5_vl, qwen2_vl |
|
This comment contains run-slow, running the specified jobs: models: ['models/glm4v', 'models/qwen2_5_omni', 'models/qwen2_5_vl', 'models/qwen2_vl'] |
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
|
[For maintainers] Suggested jobs to run (before merge):
run-slow: glm4v, qwen2_5_omni, qwen2_5_vl, qwen2_vl |
|
run-slow: glm4v |
|
This comment contains run-slow, running the specified jobs: models: ['models/glm4v'] |
|
Tests green, can be merged |
ArthurZucker
left a comment
Thanks!
* solve conflicts and remove redundant attention_mask in qwenvit
* update decoded text check
* remove trailing whitespace

What does this PR do?
Implemented a non-mask-based window attention mechanism for the eager/sdpa attention paths in Qwen2.5-VL.
For more details, see: Qwen2.5-VL Issue #1049
After the modifications, both memory usage and inference time are improved, with inference time reduced by approximately 50% for eager/sdpa.
This optimization may be particularly beneficial for users working with hardware like the V100 or others that do not support Flash Attention.
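For illustration, a minimal sketch of the non-mask-based idea, assuming q, k, and v of shape [num_heads, seq_len, head_dim] and cu_seqlens holding the cumulative window boundaries as in the vision attention above: attention is run per window instead of building a full seq_len x seq_len additive mask.

import torch
import torch.nn.functional as F

def windowed_sdpa(q, k, v, cu_seqlens):
    # run attention independently inside each window defined by cu_seqlens
    outputs = []
    for start, end in zip(cu_seqlens[:-1].tolist(), cu_seqlens[1:].tolist()):
        outputs.append(F.scaled_dot_product_attention(q[:, start:end], k[:, start:end], v[:, start:end]))
    return torch.cat(outputs, dim=1)  # [num_heads, seq_len, head_dim]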
settings:
Models: