Conversation

@JJJYmmm (Contributor) commented Apr 8, 2025

What does this PR do?

Implements a non-mask-based window attention mechanism for the eager/SDPA attention implementations in Qwen2.5-VL.
For more details, see: Qwen2.5-VL Issue #1049

After the modifications, both memory usage and inference time improve, with inference time roughly halved for eager/SDPA.

This optimization should be particularly beneficial for users on hardware that does not support Flash Attention, such as the V100.
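For readers skimming the diff, here is a minimal sketch of the idea (simplified from the eager/SDPA forward shown later in this thread; the helper name and shapes are illustrative, not the PR code): instead of building a full (seq_len, seq_len) additive mask with -inf outside each window, q/k/v are sliced per window along cu_seqlens and attention runs only inside each slice.

```python
# Minimal sketch of per-window attention, assuming q/k/v of shape
# (num_heads, seq_len, head_dim) and cu_seqlens holding cumulative window
# boundaries such as [0, 64, 128, ...]; `windowed_attention` is illustrative.
import torch
import torch.nn.functional as F


def windowed_attention(q, k, v, cu_seqlens):
    outputs = []
    for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:]):
        # Attention is computed within each window only, so no O(seq_len^2)
        # mask or cross-window attention weights are ever materialized.
        outputs.append(
            F.scaled_dot_product_attention(q[:, start:end], k[:, start:end], v[:, start:end])
        )
    return torch.cat(outputs, dim=1)  # (num_heads, seq_len, head_dim)
```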

Settings:

  • single A100 GPU
  • image size: 1000x2530
  • repeated 20 times
**flash attn:**
Mean Inference time: 0.22 seconds
Peak GPU memory allocated: 16295.05 MB

**eager:**
Mean Inference time: 2.43 seconds
Peak GPU memory allocated: 38783.01 MB

**eager (modified):**
Mean Inference time: 1.39 seconds
Peak GPU memory allocated: 37503.01 MB

**sdpa:**
Mean Inference time: 4.30 seconds
Peak GPU memory allocated: 36929.76 MB

**sdpa (modified):**
Mean Inference time: 1.88 seconds
Peak GPU memory allocated: 36509.76 MB
import time
from PIL import Image

import torch
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor

def inference_visual(img_url, model, processor):
    image = Image.open(img_url)
    inputs = processor.image_processor(images=[image], return_tensors="pt").to('cuda')

    repeat_times = 20
    torch.cuda.reset_peak_memory_stats()

    start_time = time.time()
    for _ in range(repeat_times):
        model.visual(inputs['pixel_values'].to(model.visual.dtype), grid_thw=inputs['image_grid_thw'])
    end_time = time.time()
    elapsed_time = end_time - start_time

    gpu_peak_memory = torch.cuda.max_memory_allocated() / (1024 ** 2)

    print(f"Mean Inference time: {(elapsed_time / repeat_times):.2f} seconds")
    print(f"Peak GPU memory allocated: {gpu_peak_memory:.2f} MB")


if __name__ == "__main__":
    model_path = "path/to/qwen2_5vl"
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        attn_implementation="eager",
        # attn_implementation="sdpa",
        # attn_implementation="flash_attention_2",
        device_map="cpu",
    ).cuda()
    processor = AutoProcessor.from_pretrained(model_path)

    image_path = 'Qwen2.5-vl-Capybara.png'  # https://qianwen-res.oss-accelerate-overseas.aliyuncs.com/Qwen2.5-vl-Capybara.png

    with torch.no_grad():
        inference_visual(image_path, model, processor)

Models:

github-actions bot commented Apr 8, 2025

Hi 👋, thank you for opening this pull request! The pull request is converted to draft by default. The CI will be paused while the PR is in draft mode. When it is ready for review, please click the Ready for review button (at the bottom of the PR page). This will assign reviewers and trigger CI.

@github-actions github-actions bot marked this pull request as draft April 8, 2025 10:07
@JJJYmmm JJJYmmm marked this pull request as ready for review April 8, 2025 10:09
@JJJYmmm JJJYmmm closed this Apr 8, 2025
@ArthurZucker (Collaborator):
Hey did I miss something?

@JJJYmmm JJJYmmm reopened this Apr 8, 2025
@github-actions github-actions bot marked this pull request as draft April 8, 2025 13:32
github-actions bot commented Apr 8, 2025

Hi 👋, thank you for opening this pull request! The pull request is converted to draft by default. The CI will be paused while the PR is in draft mode. When it is ready for review, please click the Ready for review button (at the bottom of the PR page). This will assign reviewers and trigger CI.

@JJJYmmm (Contributor, Author) commented Apr 8, 2025

@ArthurZucker Hey, it's nothing. I'm currently exploring the implementation of FlexAttention.

@JJJYmmm (Contributor, Author) commented Apr 8, 2025

Unfortunately, torch.nn.attention.flex_attention currently only supports head dimensions that are powers of 2 (for Qwen2.5-VL-7B's NaViT it is 1280/16 = 80), so this PR will stick with the for-loop version. 🤗
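For context, a hedged sketch of what the explored FlexAttention route could look like once the head dimension is supported (the `make_window_block_mask` helper is illustrative and not part of this PR; it follows the document-masking pattern from the FlexAttention docs):

```python
# Hedged sketch, not part of the PR: build a block mask so that tokens only
# attend within their own window, then call flex_attention with it.
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention


def make_window_block_mask(cu_seqlens, seq_len, device):
    # Map every token index to its window id.
    window_id = torch.zeros(seq_len, dtype=torch.long, device=device)
    for i in range(len(cu_seqlens) - 1):
        window_id[cu_seqlens[i]:cu_seqlens[i + 1]] = i

    def mask_mod(b, h, q_idx, kv_idx):
        # Allow attention only between tokens that share a window.
        return window_id[q_idx] == window_id[kv_idx]

    return create_block_mask(mask_mod, B=None, H=None, Q_LEN=seq_len, KV_LEN=seq_len, device=device)


# usage, with q/k/v shaped (batch, num_heads, seq_len, head_dim):
# block_mask = make_window_block_mask(cu_seqlens, seq_len, q.device)
# out = flex_attention(q, k, v, block_mask=block_mask)
```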

@JJJYmmm JJJYmmm marked this pull request as ready for review April 8, 2025 15:16
@ArthurZucker (Collaborator):
@JJJYmmm we can pad the inputs to match what's supported no?

@ArthurZucker (Collaborator) left a comment:

Looks great, can you just run the integration tests to make sure results are the same? 🤗

@JJJYmmm (Contributor, Author) commented Apr 9, 2025

@ArthurZucker Given the potentially large number of windows, parallelization may be more efficient than a for loop. I'll experiment with padding and test thoroughly to make sure it stays consistent with the original behavior. I'll update you once it's ready! 👋

@JJJYmmm (Contributor, Author) commented Apr 9, 2025

Hi, I implemented the padding-based approach and compared it with the existing for-loop implementation. Across various sequence lengths and window sizes, the padding method's performance was not significantly different from the for loop's, possibly because the padding method has to maintain extra length information for each sample. So I think the current for-loop implementation is efficient enough. 🤗

Additionally, I verified the consistency of these implementations and successfully passed the fast tests by running:

pytest tests/models/qwen2_5_vl/
====================================== 121 passed, 82 skipped, 7 warnings in 233.27s (0:03:53) ======================================

Scripts and results

import time
import math
import torch
from torch import nn

from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import apply_rotary_pos_emb_vision

from typing import Optional, Tuple

class Qwen2_5_VLVisionAttention(nn.Module):
    def __init__(self, dim: int, num_heads: int = 16) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.qkv = nn.Linear(dim, dim * 3, bias=True)
        self.proj = nn.Linear(dim, dim)
      
    def forward_origin(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: Optional[torch.Tensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> torch.Tensor:
        seq_length = hidden_states.shape[0]
        q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
        if position_embeddings is None:
            emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
            cos = emb.cos()
            sin = emb.sin()
        else:
            cos, sin = position_embeddings
        q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)

        attention_mask = torch.full(
            [1, seq_length, seq_length], torch.finfo(q.dtype).min, device=q.device, dtype=q.dtype
        )
        for i in range(1, len(cu_seqlens)):
            attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0

        q = q.transpose(0, 1)
        k = k.transpose(0, 1)
        v = v.transpose(0, 1)
        attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim)
        attn_weights = attn_weights + attention_mask
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
        attn_output = torch.matmul(attn_weights, v)
        attn_output = attn_output.transpose(0, 1)
        attn_output = attn_output.reshape(seq_length, -1)
        attn_output = self.proj(attn_output)

        return attn_output

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: Optional[torch.Tensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> torch.Tensor:
        seq_length = hidden_states.shape[0]
        q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
        if position_embeddings is None:
            emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
            cos = emb.cos()
            sin = emb.sin()
        else:
            cos, sin = position_embeddings
        q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)

        q = q.transpose(0, 1)
        k = k.transpose(0, 1)
        v = v.transpose(0, 1)

        attn_output = []
        for i in range(len(cu_seqlens) - 1):
            start = cu_seqlens[i]
            end = cu_seqlens[i + 1]

            q_window = q[:, start:end, :]
            k_window = k[:, start:end, :]
            v_window = v[:, start:end, :]

            attn_weights = torch.matmul(q_window, k_window.transpose(1, 2)) / math.sqrt(self.head_dim)
            attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
            attn_output_window = torch.matmul(attn_weights, v_window)
            attn_output.append(attn_output_window)

        attn_output = torch.cat(attn_output, dim=1).transpose(0, 1)
        attn_output = attn_output.reshape(seq_length, -1)
        attn_output = self.proj(attn_output)

        return attn_output

    def forward_padding(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: Optional[torch.Tensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> torch.Tensor:
        seq_length = hidden_states.shape[0]
        q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
        if position_embeddings is None:
            emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
            cos = emb.cos()
            sin = emb.sin()
        else:
            cos, sin = position_embeddings
        q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)

        q = q.transpose(0, 1)
        k = k.transpose(0, 1)
        v = v.transpose(0, 1)

        batch_size = cu_seqlens.numel() - 1
        window_lengths = cu_seqlens[1:] - cu_seqlens[:-1]
        max_window_length = window_lengths.max().item()

        q_padded = torch.zeros(self.num_heads, batch_size, max_window_length, self.head_dim, device=q.device, dtype=q.dtype)
        k_padded = torch.zeros(self.num_heads, batch_size, max_window_length, self.head_dim, device=k.device, dtype=k.dtype)
        v_padded = torch.zeros(self.num_heads, batch_size, max_window_length, self.head_dim, device=v.device, dtype=v.dtype)

        for i, (start, end) in enumerate(zip(cu_seqlens[:-1], cu_seqlens[1:])):
            q_padded[:, i, :end-start, :] = q[:, start:end, :]
            k_padded[:, i, :end-start, :] = k[:, start:end, :]
            v_padded[:, i, :end-start, :] = v[:, start:end, :]

        attn_weights = torch.matmul(q_padded, k_padded.transpose(-1, -2)) / math.sqrt(self.head_dim)
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)

        attn_output_padded = torch.matmul(attn_weights, v_padded)

        attn_output = []
        for i, length in enumerate(window_lengths):
            attn_output.append(attn_output_padded[:, i, :length, :])
        attn_output = torch.cat(attn_output, dim=1)

        attn_output = attn_output.transpose(0, 1)
        attn_output = attn_output.reshape(seq_length, -1)
        attn_output = self.proj(attn_output)

        return attn_output


def check_consistency(seq_len, block_size):
    # FAKE INPUTS

    hidden_states = torch.randn(seq_len, 1280).cuda()
    cu_seqlens = torch.tensor([i for i in range(0, seq_len + 1, block_size)], dtype=torch.int32).cuda()
    position_embeddings = (torch.randn(seq_len, 80).cuda(), torch.randn(seq_len, 80).cuda())

    # Attention
    eager_attention = Qwen2_5_VLVisionAttention(1280, 16).cuda()

    # Number of repetitions
    repetitions = 20

    # Measure time for modified forward
    modified_times = []
    for _ in range(repetitions):
        start_time = time.time()
        eager_attention(hidden_states, cu_seqlens, position_embeddings=position_embeddings)
        modified_times.append(time.time() - start_time)

    # Measure time for origin forward
    origin_times = []
    for _ in range(repetitions):
        start_time = time.time()
        eager_attention.forward_origin(hidden_states, cu_seqlens, position_embeddings=position_embeddings)
        origin_times.append(time.time() - start_time)

    # Measure time for padding forward
    padding_times = []
    for _ in range(repetitions):
        start_time = time.time()
        eager_attention.forward_padding(hidden_states, cu_seqlens, position_embeddings=position_embeddings)
        padding_times.append(time.time() - start_time)

    # Calculate average times
    avg_modified_time = sum(modified_times) / repetitions
    avg_origin_time = sum(origin_times) / repetitions
    avg_padding_time = sum(padding_times) / repetitions

    # Print results
    print(f"Average time for modified forward: {avg_modified_time:.6f} seconds")
    print(f"Average time for origin forward: {avg_origin_time:.6f} seconds")
    print(f"Average time for padding forward: {avg_padding_time:.6f} seconds")

    # Consistency checks
    output_modified = eager_attention(hidden_states, cu_seqlens, position_embeddings=position_embeddings)
    output_origin = eager_attention.forward_origin(hidden_states, cu_seqlens, position_embeddings=position_embeddings)
    output_padding = eager_attention.forward_padding(hidden_states, cu_seqlens, position_embeddings=position_embeddings)

    assert torch.allclose(output_modified, output_origin, atol=1e-5), "Output is not consistent between modified and origin!"
    assert torch.allclose(output_padding, output_origin, atol=1e-5), "Output is not consistent between padding and origin!"

    print("Output is consistent!")


if __name__ == "__main__":
    with torch.no_grad():
        seq_len = [12096, 4096]
        block_size = [[32, 64, 64 * 3], [32, 64]]
        for i in range(len(seq_len)):
            for j in range(len(block_size[i])):
                print(f"===== seq_len: {seq_len[i]}, block_size: {block_size[i][j]} =====")
                check_consistency(seq_len[i], block_size[i][j])
===== seq_len: 12096, block_size: 32 =====
Average time for modified forward: 0.073282 seconds
Average time for origin forward: 0.125473 seconds
Average time for padding forward: 0.093071 seconds
Output is consistent!
===== seq_len: 12096, block_size: 64 =====
Average time for modified forward: 0.038572 seconds
Average time for origin forward: 0.113410 seconds
Average time for padding forward: 0.120773 seconds
Output is consistent!
===== seq_len: 12096, block_size: 192 =====
Average time for modified forward: 0.025483 seconds
Average time for origin forward: 0.125795 seconds
Average time for padding forward: 0.040718 seconds
Output is consistent!
===== seq_len: 4096, block_size: 32 =====
Average time for modified forward: 0.038642 seconds
Average time for origin forward: 0.046402 seconds
Average time for padding forward: 0.062893 seconds
Output is consistent!
===== seq_len: 4096, block_size: 64 =====
Average time for modified forward: 0.021746 seconds
Average time for origin forward: 0.028758 seconds
Average time for padding forward: 0.029745 seconds
Output is consistent!

@JJJYmmm JJJYmmm force-pushed the main branch 2 times, most recently from 50a81dc to 397511e Compare April 9, 2025 14:28
@JJJYmmm (Contributor, Author) commented Apr 9, 2025

CC @ArthurZucker all tests are passing! 😊

@JJJYmmm JJJYmmm requested a review from ArthurZucker April 10, 2025 11:59
@ArthurZucker (Collaborator) left a comment:

Thanks 🫡
I think kernels will be a great fit long-term here!

Just to check, did you run this with compile (e.g. generate with a static cache)?

@JJJYmmm (Contributor, Author) commented Apr 11, 2025

@ArthurZucker Agreed that kernels are a great long-term fit; I'll experiment and open a new PR accordingly. 🙌

Tested generation with static cache; all good here!
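For anyone wanting to reproduce that check, a hedged sketch (reusing `model` and `processor` from the benchmark script in the PR description; the prompt construction and generation flags follow standard transformers static-cache usage and are not part of this PR):

```python
# Hedged sketch: run generate with a static KV cache, reusing `model` and
# `processor` from the benchmark script above.
import torch
from PIL import Image

image = Image.open("Qwen2.5-vl-Capybara.png")
messages = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "Describe this image."}]}]
prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = processor(text=[prompt], images=[image], return_tensors="pt").to(model.device)

with torch.no_grad():
    out = model.generate(
        **inputs,
        max_new_tokens=32,
        do_sample=False,
        cache_implementation="static",  # static KV cache, as asked about above
    )
print(processor.batch_decode(out[:, inputs["input_ids"].shape[1]:], skip_special_tokens=True)[0])
```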

@JJJYmmm (Contributor, Author) commented May 27, 2025

@ArthurZucker Hey, these new commits fix the precision issue with SDPA attention reported in QwenLM/Qwen2.5-VL#1235.

I made the changes directly here since they conflicted with this PR. I've also tested the models internally, and they behave as expected. 🫡

@JJJYmmm JJJYmmm changed the title Implement improved window attention in eager/sdpa version for Qwen2.5VL Fix SDPA attention precision issue in Qwen2.5-VL May 27, 2025
github-actions bot commented Jul 2, 2025

This comment contains run-slow, running the specified jobs:

models: ['models/glm4v', 'models/qwen2_5_omni', 'models/qwen2_5_vl', 'models/qwen2_vl']
quantizations: [] ...

@zucchini-nlp (Member) left a comment:

Looks good to me, as this fixes equivalence issues between FA2 and sdpa/eager. I am surprised that the for-loop option is more efficient than padding hehe

Triggered slow tests, let's see if they pass. In the meantime, can I ask you to check whether this resolves #39067 or not?

@JJJYmmm (Contributor, Author) commented Jul 2, 2025

We also used a for loop in vLLM to make it faster and save memory before my internship, haha.
For #39067, the for loop did reduce the logit diff, but the difference is still there. I found that the diff seems to accumulate across layers.
I also wrote a script and found something weird: different attention implementations lead to a non-trivial difference, even between the eager one and torch's scaled_dot_product_attention. But I haven't found the exact reason yet. 🥹

import torch
import torch.nn as nn
import math
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

from typing import Optional


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    key_states = repeat_kv(key, module.num_key_value_groups)
    value_states = repeat_kv(value, module.num_key_value_groups)

    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
        attn_weights = attn_weights + causal_mask

    attn_weights = nn.functional.softmax(
        attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(
        attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value_states)
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights


# copy from https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0,
                                 is_causal=False, scale=None, enable_gqa=False) -> torch.Tensor:
    L, S = query.size(-2), key.size(-2)
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
    if is_causal:
        assert attn_mask is None
        temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
        attn_bias.to(query.dtype)

    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
        else:
            attn_bias = attn_mask + attn_bias

    if enable_gqa:
        key = key.repeat_interleave(query.size(-3)//key.size(-3), -3)
        value = value.repeat_interleave(query.size(-3)//value.size(-3), -3)

    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    attn_weight += attn_bias
    attn_weight = torch.softmax(attn_weight, dim=-1)
    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
    return attn_weight @ value


bs, num_heads, seq_len, head_dim = 1, 16, 1024, 128
x = torch.randn(bs, num_heads, seq_len, head_dim,
                dtype=torch.bfloat16, device="cuda")
kv = torch.randn(bs, num_heads, seq_len, head_dim,
                 dtype=torch.bfloat16, device="cuda")


class Config:
    def __init__(self):
        self._attn_implementation = None


class TestModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.num_key_value_groups = 1
        self.scaling = head_dim**-0.5
        self.training = False
        self.config = Config()
        self.is_causal = False


test_module = TestModule()

# test eager_attention_forward
test_module.config._attn_implementation = "eager"
input_eager = x.clone()
with torch.no_grad():
    output_eager = eager_attention_forward(
        test_module, input_eager, kv, kv, attention_mask=None, scaling=test_module.scaling, dropout=0.0)[0]


# test flash_attention_forward
test_module.config._attn_implementation = "flash_attention_2"
input_flash = x.clone()
cu_seqlens = torch.tensor(
    [0, seq_len], dtype=torch.int32, device=input_flash.device)
with torch.no_grad():
    output_flash = ALL_ATTENTION_FUNCTIONS["flash_attention_2"](test_module, input_flash, kv, kv, attention_mask=None, scaling=test_module.scaling, dropout=0.0,
                                                                cu_seq_lens_q=cu_seqlens, cu_seq_lens_k=cu_seqlens, max_length_q=seq_len, max_length_k=seq_len, is_causal=False)[0]

# test sdpa_attention_forward
test_module.config._attn_implementation = "sdpa"
input_sdpa = x.clone()
with torch.no_grad():
    output_sdpa = ALL_ATTENTION_FUNCTIONS["sdpa"](test_module, input_sdpa, kv, kv, attention_mask=None, scaling=test_module.scaling, dropout=0.0,
                                                  is_causal=False)[0]
    output_sdpa_manual = scaled_dot_product_attention(input_sdpa, kv, kv,
                                                      attn_mask=None, dropout_p=0.0, is_causal=False, scale=test_module.scaling, enable_gqa=False)
    output_sdpa_manual = output_sdpa_manual.transpose(1, 2).contiguous()

print("sdpa and sdpa_manual outputs are equal:",
      torch.allclose(output_sdpa, output_sdpa_manual, atol=1e-3))
print("eager and flash_attention_2 outputs are equal:",
      torch.allclose(output_eager, output_flash, atol=1e-3))
print("eager and sdpa outputs are equal:",
      torch.allclose(output_eager, output_sdpa, atol=1e-3))
print("flash_attention_2 and sdpa outputs are equal:",
      torch.allclose(output_flash, output_sdpa, atol=1e-3))

print("sdpa_manual and eager outputs are equal:",
      torch.allclose(output_sdpa_manual, output_eager, atol=1e-3))
print("sdpa_manual and flash_attention_2 outputs are equal:",
      torch.allclose(output_sdpa_manual, output_flash, atol=1e-3))
sdpa and sdpa_manual outputs are equal: False
eager and flash_attention_2 outputs are equal: False
eager and sdpa outputs are equal: False
flash_attention_2 and sdpa outputs are equal: False
sdpa_manual and eager outputs are equal: True
sdpa_manual and flash_attention_2 outputs are equal: False

Update: I think my comparison method was too strict because allclose checks every element. The outputs look much more reasonable when using relative mean difference. 🥲

eager and flash_attention_2 Relative mean diff: 0.005432
eager and sdpa Relative mean diff: 0.005432
eager and sdpa_manual Relative mean diff: 0.000000
flash_attention_2 and sdpa Relative mean diff: 0.000001
flash_attention_2 and sdpa_manual Relative mean diff: 0.005432
sdpa and sdpa_manual Relative mean diff: 0.005432
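For completeness, the relative mean difference above can be computed with a small helper like the following (an assumed definition for illustration; the exact formula used is not shown in the script above):

```python
# Assumed definition of the relative mean difference reported above.
import torch


def rel_mean_diff(a: torch.Tensor, b: torch.Tensor) -> float:
    return ((a - b).abs().mean() / b.abs().mean()).item()


# e.g. rel_mean_diff(output_eager, output_flash) for the tensors from the script above
```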

@zucchini-nlp (Member):
Interesting, I guess Qwen-ViT is indeed not very stable. Usually we don't see a huge difference when comparing attention implementations for other models.

@JJJYmmm (Contributor, Author) commented Jul 7, 2025

Yes, exactly. I think Qwen-ViT is unstable because its image feature values are so much larger than the text embedding values. We actually measured this before: going into the first LM layer, the image feature norm is about 30x the text embedding norm. We're working on a fix for this. 🧐
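A hedged sketch of how such a norm comparison might look (reusing `model` and the processor `inputs` from the generation sketch above, which contain input_ids, pixel_values, and image_grid_thw; attribute paths are assumptions and may differ across transformers versions):

```python
# Hedged sketch: compare the average L2 norm of image features vs. text token
# embeddings going into the first LM layer. Attribute paths are assumptions.
import torch

with torch.no_grad():
    image_embeds = model.visual(
        inputs["pixel_values"].to(model.visual.dtype), grid_thw=inputs["image_grid_thw"]
    )
    text_embeds = model.get_input_embeddings()(inputs["input_ids"])

print("mean image feature norm:", image_embeds.norm(dim=-1).mean().item())
print("mean text embedding norm:", text_embeds.norm(dim=-1).mean().item())
```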

@zucchini-nlp (Member):
run-slow: glm4v, qwen2_5_omni, qwen2_5_vl, qwen2_vl

github-actions bot commented Jul 8, 2025

This comment contains run-slow, running the specified jobs:

models: ['models/glm4v', 'models/qwen2_5_omni', 'models/qwen2_5_vl', 'models/qwen2_vl']
quantizations: [] ...

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@zucchini-nlp (Member):
tests/models/glm4v/test_modeling_glm4v.py::Glm4vIntegrationTest::test_small_model_integration_test_batch_different_resolutions is failing, not sure if it is related. The other failing slow tests are torchscript ones and thus not related

github-actions bot commented Jul 8, 2025

[For maintainers] Suggested jobs to run (before merge)

run-slow: glm4v, qwen2_5_omni, qwen2_5_vl, qwen2_vl

@JJJYmmm (Contributor, Author) commented Jul 8, 2025

Thanks for pointing this out! The test was failing because the logits changed when we switched the SDPA implementation from padding to windowing. It should be resolved now. 🤗
[Screenshot: 2025-07-08 19:44:27]

@zucchini-nlp (Member):
run-slow: glm4v

github-actions bot commented Jul 8, 2025

This comment contains run-slow, running the specified jobs:

models: ['models/glm4v']
quantizations: [] ...

@zucchini-nlp (Member):
Tests green, can be merged

@ArthurZucker (Collaborator) left a comment:

Thanks!

@zucchini-nlp zucchini-nlp merged commit 25343aa into huggingface:main Jul 9, 2025
20 checks passed
rjgleaton pushed a commit to rjgleaton/transformers that referenced this pull request Jul 17, 2025
* solve conflicts and remove  redundant attention_mask in qwenvit

* update decoded text check

* remove trailing whitespace
zaristei pushed a commit to zaristei/transformers that referenced this pull request Sep 9, 2025
* solve conflicts and remove redundant attention_mask in qwenvit

* update decoded text check

* remove trailing whitespace