
Qwen2.5-VL using Ascend NPU with flash-attention-2 raises error #38189

@llan-ml

Description

System Info

  • transformers version: 4.52.0.dev0
  • Platform: Linux-4.19.90-vhulk2211.3.0.h1543.eulerosv2r10.aarch64-aarch64-with-glibc2.31
  • Python version: 3.10.5
  • Huggingface_hub version: 0.30.2
  • Safetensors version: 0.5.3
  • Accelerate version: 1.6.0
  • Accelerate config: not found
  • DeepSpeed version: 0.16.7
  • PyTorch version (GPU?): 2.3.1 (False)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?:
  • Using NPU in script?:
  • NPU type: Ascend910B4
  • CANN version: 8.0.0

Who can help?

@FightingZhen

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

reproduction script:

import os

os.environ["NPU_VISIBLE_DEVICES"]="0"
os.environ["ASCEND_RT_VISIBLE_DEVICES"]="0"

from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor
from qwen_vl_utils import process_vision_info

# default: Load the model on the available device(s)
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    "/cache/Qwen2.5-VL-7B-Instruct/", torch_dtype="auto", device_map="auto", attn_implementation="flash_attention_2"
)

# We recommend enabling flash_attention_2 for better acceleration and memory saving, especially in multi-image and video scenarios.
# model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
#     "Qwen/Qwen2.5-VL-7B-Instruct",
#     torch_dtype=torch.bfloat16,
#     attn_implementation="flash_attention_2",
#     device_map="auto",
# )

# default processor
# processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")

# The default range for the number of visual tokens per image in the model is 4-16384.
# You can set min_pixels and max_pixels according to your needs, such as a token range of 256-1280, to balance performance and cost.
min_pixels = 256*28*28
max_pixels = 1280*28*28
processor = AutoProcessor.from_pretrained("/cache/Qwen2.5-VL-7B-Instruct/", min_pixels=min_pixels, max_pixels=max_pixels)

messages = [
    {
        "role": "user",
        "content": [
            {
                "type": "image",
                "image": "file:///home/ma-user/work/demo.jpeg",
            },
            {"type": "text", "text": "Describe this image."},
        ],
    }
]

# Preparation for inference
text = processor.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
image_inputs, video_inputs = process_vision_info(messages)
inputs = processor(
    text=[text],
    images=image_inputs,
    videos=video_inputs,
    padding=True,
    return_tensors="pt",
)
inputs = inputs.to("npu")

# Inference: Generation of the output
generated_ids = model.generate(**inputs, max_new_tokens=128)
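
For reference, the failure appears to be specific to the flash_attention_2 path. Below is a minimal, untested sketch (same local checkpoint path as above) that loads the model with the SDPA backend instead, to check whether the error is confined to the flash-attention rotary-embedding code path:

# Untested sketch: identical to the load above except for attn_implementation,
# to isolate whether the failure is specific to flash_attention_2.
model_sdpa = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    "/cache/Qwen2.5-VL-7B-Instruct/",
    torch_dtype="auto",
    device_map="auto",
    attn_implementation="sdpa",
)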

log:

/usr/local/python3.10.5/lib/python3.10/site-packages/torch_npu/utils/path_manager.py:82: UserWarning: Warning: The /usr/local/Ascend/ascend-toolkit/latest owner does not match the current user.
  warnings.warn(f"Warning: The {path} owner does not match the current user.")
/usr/local/python3.10.5/lib/python3.10/site-packages/torch_npu/utils/path_manager.py:82: UserWarning: Warning: The /usr/local/Ascend/ascend-toolkit/8.0.0/aarch64-linux/ascend_toolkit_install.info owner does not match the current user.
  warnings.warn(f"Warning: The {path} owner does not match the current user.")
[W compiler_depend.ts:615] Warning: expandable_segments currently defaults to false. You can enable this feature by `export PYTORCH_NPU_ALLOC_CONF = expandable_segments:True`. (function operator())
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:04<00:00,  1.03it/s]
Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
You have video processor config saved in `preprocessor.json` file which is deprecated. Video processor configs should be saved in their own `video_preprocessor.json` file. You can rename the file or load and save the processor back which renames it automatically. Loading from `preprocessor.json` will be removed in v5.0.
[W compiler_depend.ts:51] Warning: CAUTION: The operator 'aten::isin.Tensor_Tensor_out' is not currently supported on the NPU backend and will fall back to run on the CPU. This may have performance implications. (function npu_cpu_fallback)
Traceback (most recent call last):
  File "/home/ma-user/work/test_qwen25vl.py", line 59, in <module>
    generated_ids = model.generate(**inputs, max_new_tokens=128)
  File "/usr/local/python3.10.5/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/ma-user/work/transformers/src/transformers/generation/utils.py", line 2592, in generate
    result = self._sample(
  File "/home/ma-user/work/transformers/src/transformers/generation/utils.py", line 3552, in _sample
    outputs = self(**model_inputs, return_dict=True)
  File "/usr/local/python3.10.5/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/python3.10.5/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ma-user/work/transformers/src/transformers/utils/generic.py", line 969, in wrapper
    output = func(self, *args, **kwargs)
  File "/home/ma-user/work/transformers/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py", line 1907, in forward
    outputs = self.model(
  File "/usr/local/python3.10.5/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/python3.10.5/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ma-user/work/transformers/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py", line 1660, in forward
    image_embeds = self.get_image_features(pixel_values, image_grid_thw)
  File "/home/ma-user/work/transformers/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py", line 1613, in get_image_features
    image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
  File "/usr/local/python3.10.5/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/python3.10.5/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ma-user/work/transformers/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py", line 530, in forward
    hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens_now, position_embeddings=position_embeddings)
  File "/usr/local/python3.10.5/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/python3.10.5/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ma-user/work/transformers/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py", line 341, in forward
    hidden_states = hidden_states + self.attn(
  File "/usr/local/python3.10.5/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/python3.10.5/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ma-user/work/transformers/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py", line 189, in forward
    q, k = apply_rotary_pos_emb_flashatt(q.unsqueeze(0), k.unsqueeze(0), cos, sin)
  File "/home/ma-user/work/transformers/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py", line 156, in apply_rotary_pos_emb_flashatt
    q_embed = apply_rotary_emb(q.float(), cos.float(), sin.float()).type_as(q)
  File "/usr/local/python3.10.5/lib/python3.10/site-packages/torch/_ops.py", line 854, in __call__
    return self_._op(*args, **(kwargs or {}))
RuntimeError: The size of tensor a (40) must match the size of tensor b (80) at non-singleton dimension 3
[ERROR] 2025-05-18-14:48:37 (PID:123684, Device:0, RankID:-1) ERR99999 UNKNOWN application exception
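
My reading of the shape mismatch (an assumption from the traceback, not verified against torch_npu internals): the Qwen2.5-VL-7B vision tower uses head_dim = 1280 / 16 = 80, and apply_rotary_pos_emb_flashatt halves the cos/sin tables (80 -> 40) before calling apply_rotary_emb, which is the layout the upstream flash-attn kernel expects; the NPU substitute for apply_rotary_emb seems to expect full-width (80) tables instead, hence "40 must match 80 at dimension 3". A CPU-only toy sketch of that arithmetic:

import torch

# Hypothetical shapes matching the Qwen2.5-VL-7B vision tower (assumption:
# hidden_size 1280, 16 heads -> head_dim 80).
head_dim = 80
seq_len, num_heads = 16, 16

q = torch.randn(1, seq_len, num_heads, head_dim)  # (1, seq, heads, 80)
cos_full = torch.randn(seq_len, head_dim)         # full-width rotary table (80)
cos_half = cos_full.chunk(2, dim=-1)[0]           # half-width table (40), as passed on the flash-attn path

# Broadcasting a full-width query against the half-width table fails the same
# way as the log above: 40 vs 80 at non-singleton dimension 3.
try:
    _ = q * cos_half[None, :, None, :]
except RuntimeError as e:
    print(e)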

Expected behavior

No error
