
[Bug]: tests/v1/tpu/test_sampler.py crashes due to ragged_paged_attention arg mismatch #15257


Description

@hyeygit

Your current environment

Environment

Collecting environment information...
PyTorch version: 2.7.0a0+gited9c8a5
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: Debian GNU/Linux 11 (bullseye) (x86_64)
GCC version: (Debian 10.2.1-6) 10.2.1 20210110
Clang version: Could not collect
CMake version: version 3.31.4
Libc version: glibc-2.31

Python version: 3.10.16 (main, Feb 4 2025, 07:26:46) [GCC 10.2.1 20210110] (64-bit runtime)
Python platform: Linux-6.8.0-1015-gcp-x86_64-with-glibc2.31
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
Address sizes: 52 bits physical, 57 bits virtual
CPU(s): 44
On-line CPU(s) list: 0-43
Thread(s) per core: 2
Core(s) per socket: 22
Socket(s): 1
NUMA node(s): 1
Vendor ID: AuthenticAMD
CPU family: 25
Model: 17
Model name: AMD EPYC 9B14
Stepping: 1
CPU MHz: 2599.996
BogoMIPS: 5199.99
Hypervisor vendor: KVM
Virtualization type: full
L1d cache: 704 KiB
L1i cache: 704 KiB
L2 cache: 22 MiB
L3 cache: 64 MiB
NUMA node0 CPU(s): 0-43
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Mitigation; Safe RET
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines; IBPB conditional; IBRS_FW; STIBP always-on; RSB filling; PBRSB-eIBRS Not affected; BHI Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm cmp_legacy cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw topoext ssbd ibrs ibpb stibp vmmcall fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves avx512_bf16 clzero xsaveerptr wbnoinvd arat avx512vbmi umip avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq rdpid fsrm

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] pyzmq==26.2.1
[pip3] torch==2.7.0a0+gited9c8a5
[pip3] torch-xla==2.7.0+git2c70a1c
[pip3] transformers==4.49.0
[conda] Could not collect
ROCM Version: Could not collect
Neuron SDK Version: N/A
vLLM Version: 0.7.4.dev240+g3a6f9906
vLLM Build Flags:
CUDA Archs: Not Set; ROCm: Disabled; Neuron: Disabled
GPU Topology:
Could not collect

LD_LIBRARY_PATH=/usr/local/lib/python3.10/site-packages/cv2/../../lib64::/usr/local/lib
NCCL_CUMEM_ENABLE=0
TORCHINDUCTOR_COMPILE_THREADS=1

🐛 Describe the bug

To reproduce, run

VLLM_USE_V1=1 pytest -s tests/v1/tpu/test_sampler.py

The test is killed (crashes) with:

...
ERROR 03-21 00:10:17 [core.py:330]   File "/root/pytorch/xla/torch_xla/experimental/custom_kernel.py", line 906, in validate_ragged_paged_attention_inputs
ERROR 03-21 00:10:17 [core.py:330]     _, _, num_kv_heads, head_dim_k = k_pages.shape
ERROR 03-21 00:10:17 [core.py:330] ValueError: not enough values to unpack (expected 4, got 3)
...
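The immediate failure is a 4-way unpack on a 3-D tensor. It can be reproduced in isolation; the sketch below is illustrative only, with made-up dimensions standing in for whatever page layout vLLM's Pallas backend actually builds:

import torch

# Hypothetical 3-D KV page tensor, e.g. (num_pages, page_size, num_kv_heads * head_dim).
# The concrete shape vLLM passes is an assumption here; only the rank (3) matters.
k_pages = torch.zeros(16, 128, 8 * 64)

# torch_xla's validate_ragged_paged_attention_inputs unpacks four dimensions:
_, _, num_kv_heads, head_dim_k = k_pages.shape
# ValueError: not enough values to unpack (expected 4, got 3)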
Full stack trace:
ERROR 03-21 00:10:17 [core.py:330]   File "/root/vllm-official/vllm/v1/engine/core.py", line 323, in run_engine_core
ERROR 03-21 00:10:17 [core.py:330]     engine_core.run_busy_loop()
ERROR 03-21 00:10:17 [core.py:330]   File "/root/vllm-official/vllm/v1/engine/core.py", line 357, in run_busy_loop
ERROR 03-21 00:10:17 [core.py:330]     outputs = step_fn()
ERROR 03-21 00:10:17 [core.py:330]   File "/root/vllm-official/vllm/v1/engine/core.py", line 182, in step
ERROR 03-21 00:10:17 [core.py:330]     output = self.model_executor.execute_model(scheduler_output)
ERROR 03-21 00:10:17 [core.py:330]   File "/root/vllm-official/vllm/v1/executor/abstract.py", line 80, in execute_model
ERROR 03-21 00:10:17 [core.py:330]     output = self.collective_rpc("execute_model",
ERROR 03-21 00:10:17 [core.py:330]   File "/root/vllm-official/vllm/executor/uniproc_executor.py", line 56, in collective_rpc
ERROR 03-21 00:10:17 [core.py:330]     answer = run_method(self.driver_worker, method, args, kwargs)
ERROR 03-21 00:10:17 [core.py:330]   File "/root/vllm-official/vllm/utils.py", line 2221, in run_method
ERROR 03-21 00:10:17 [core.py:330]     return func(*args, **kwargs)
ERROR 03-21 00:10:17 [core.py:330]   File "/root/vllm-official/vllm/v1/worker/tpu_worker.py", line 166, in execute_model
ERROR 03-21 00:10:17 [core.py:330]     output = self.model_runner.execute_model(scheduler_output)
ERROR 03-21 00:10:17 [core.py:330]   File "/root/pytorch/torch/utils/_contextlib.py", line 116, in decorate_context
ERROR 03-21 00:10:17 [core.py:330]     return func(*args, **kwargs)
ERROR 03-21 00:10:17 [core.py:330]   File "/root/vllm-official/vllm/v1/worker/tpu_model_runner.py", line 605, in execute_model
ERROR 03-21 00:10:17 [core.py:330]     hidden_states = self.model(
ERROR 03-21 00:10:17 [core.py:330]   File "/root/pytorch/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
ERROR 03-21 00:10:17 [core.py:330]     return self._call_impl(*args, **kwargs)
ERROR 03-21 00:10:17 [core.py:330]   File "/root/pytorch/torch/nn/modules/module.py", line 1762, in _call_impl
ERROR 03-21 00:10:17 [core.py:330]     return forward_call(*args, **kwargs)
ERROR 03-21 00:10:17 [core.py:330]   File "/root/pytorch/torch/_dynamo/eval_frame.py", line 655, in _fn
ERROR 03-21 00:10:17 [core.py:330]     return fn(*args, **kwargs)
ERROR 03-21 00:10:17 [core.py:330]   File "/root/pytorch/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
ERROR 03-21 00:10:17 [core.py:330]     return self._call_impl(*args, **kwargs)
ERROR 03-21 00:10:17 [core.py:330]   File "/root/pytorch/torch/nn/modules/module.py", line 1762, in _call_impl
ERROR 03-21 00:10:17 [core.py:330]     return forward_call(*args, **kwargs)
ERROR 03-21 00:10:17 [core.py:330]   File "/root/vllm-official/vllm/v1/worker/tpu_model_runner.py", line 866, in forward
ERROR 03-21 00:10:17 [core.py:330]     def forward(
ERROR 03-21 00:10:17 [core.py:330]   File "/root/pytorch/torch/_dynamo/eval_frame.py", line 838, in _fn
ERROR 03-21 00:10:17 [core.py:330]     return fn(*args, **kwargs)
ERROR 03-21 00:10:17 [core.py:330]   File "/root/pytorch/torch/_functorch/aot_autograd.py", line 1201, in forward
ERROR 03-21 00:10:17 [core.py:330]     return compiled_fn(full_args)
ERROR 03-21 00:10:17 [core.py:330]   File "/root/pytorch/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 328, in runtime_wrapper
ERROR 03-21 00:10:17 [core.py:330]     all_outs = call_func_at_runtime_with_args(
ERROR 03-21 00:10:17 [core.py:330]   File "/root/pytorch/torch/_functorch/_aot_autograd/utils.py", line 126, in call_func_at_runtime_with_args
ERROR 03-21 00:10:17 [core.py:330]     out = normalize_as_list(f(args))
ERROR 03-21 00:10:17 [core.py:330]   File "/root/pytorch/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 689, in inner_fn
ERROR 03-21 00:10:17 [core.py:330]     outs = compiled_fn(args)
ERROR 03-21 00:10:17 [core.py:330]   File "/root/pytorch/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 495, in wrapper
ERROR 03-21 00:10:17 [core.py:330]     return compiled_fn(runtime_args)
ERROR 03-21 00:10:17 [core.py:330]   File "/root/pytorch/torch/_functorch/_aot_autograd/utils.py", line 100, in g
ERROR 03-21 00:10:17 [core.py:330]     return f(*args)
ERROR 03-21 00:10:17 [core.py:330]   File "/root/pytorch/torch/_dynamo/backends/torchxla.py", line 37, in fwd
ERROR 03-21 00:10:17 [core.py:330]     compiled_graph = bridge.extract_compiled_graph(model, args)
ERROR 03-21 00:10:17 [core.py:330]   File "/root/pytorch/xla/torch_xla/_dynamo/dynamo_bridge.py", line 709, in extract_compiled_graph
ERROR 03-21 00:10:17 [core.py:330]     return extract_compiled_graph_helper(xla_model, xla_args)
ERROR 03-21 00:10:17 [core.py:330]   File "/root/pytorch/xla/torch_xla/_dynamo/dynamo_bridge.py", line 830, in extract_compiled_graph_helper
ERROR 03-21 00:10:17 [core.py:330]     return partition_fx_graph_for_cpu_fallback(xla_model, xla_args,
ERROR 03-21 00:10:17 [core.py:330]   File "/root/pytorch/xla/torch_xla/_dynamo/dynamo_bridge.py", line 734, in partition_fx_graph_for_cpu_fallback
ERROR 03-21 00:10:17 [core.py:330]     collector.run(*xla_args)
ERROR 03-21 00:10:17 [core.py:330]   File "/root/pytorch/torch/fx/interpreter.py", line 171, in run
ERROR 03-21 00:10:17 [core.py:330]     self.env[node] = self.run_node(node)
ERROR 03-21 00:10:17 [core.py:330]   File "/root/pytorch/xla/torch_xla/_dynamo/dynamo_bridge.py", line 627, in run_node
ERROR 03-21 00:10:17 [core.py:330]     result = super().run_node(n)
ERROR 03-21 00:10:17 [core.py:330]   File "/root/pytorch/torch/fx/interpreter.py", line 240, in run_node
ERROR 03-21 00:10:17 [core.py:330]     return getattr(self, n.op)(n.target, args, kwargs)
ERROR 03-21 00:10:17 [core.py:330]   File "/root/pytorch/torch/fx/interpreter.py", line 320, in call_function
ERROR 03-21 00:10:17 [core.py:330]     return target(*args, **kwargs)
ERROR 03-21 00:10:17 [core.py:330]   File "/root/pytorch/torch/_ops.py", line 756, in __call__
ERROR 03-21 00:10:17 [core.py:330]     return self._op(*args, **kwargs)
ERROR 03-21 00:10:17 [core.py:330]   File "/root/pytorch/xla/torch_xla/experimental/custom_kernel.py", line 1751, in ragged_paged_attention_xla
ERROR 03-21 00:10:17 [core.py:330]     return ragged_paged_attention(
ERROR 03-21 00:10:17 [core.py:330]   File "/root/pytorch/xla/torch_xla/experimental/custom_kernel.py", line 181, in wrapper
ERROR 03-21 00:10:17 [core.py:330]     return func(*args, **kwargs)
ERROR 03-21 00:10:17 [core.py:330]   File "/root/pytorch/xla/torch_xla/experimental/custom_kernel.py", line 1012, in ragged_paged_attention
ERROR 03-21 00:10:17 [core.py:330]     validate_ragged_paged_attention_inputs(q, k_pages, v_pages, kv_lens,
ERROR 03-21 00:10:17 [core.py:330]   File "/root/pytorch/xla/torch_xla/experimental/custom_kernel.py", line 906, in validate_ragged_paged_attention_inputs
ERROR 03-21 00:10:17 [core.py:330]     _, _, num_kv_heads, head_dim_k = k_pages.shape
ERROR 03-21 00:10:17 [core.py:330] ValueError: not enough values to unpack (expected 4, got 3)

...

ERROR 03-21 00:10:17 [core.py:330] Original traceback:
ERROR 03-21 00:10:17 [core.py:330]   File "/root/vllm-official/vllm/v1/worker/tpu_model_runner.py", line 884, in forward
ERROR 03-21 00:10:17 [core.py:330]     hidden_states = self.model(
ERROR 03-21 00:10:17 [core.py:330]   File "/root/vllm-official/vllm/model_executor/models/qwen2.py", line 462, in forward
ERROR 03-21 00:10:17 [core.py:330]     hidden_states = self.model(input_ids, positions, intermediate_tensors,
ERROR 03-21 00:10:17 [core.py:330]   File "/root/vllm-official/vllm/compilation/decorators.py", line 172, in __call__
ERROR 03-21 00:10:17 [core.py:330]     return self.forward(*args, **kwargs)
ERROR 03-21 00:10:17 [core.py:330]   File "/root/vllm-official/vllm/model_executor/models/qwen2.py", line 338, in forward
ERROR 03-21 00:10:17 [core.py:330]     hidden_states, residual = layer(
ERROR 03-21 00:10:17 [core.py:330]   File "/root/vllm-official/vllm/model_executor/models/qwen2.py", line 243, in forward
ERROR 03-21 00:10:17 [core.py:330]     hidden_states = self.self_attn(
ERROR 03-21 00:10:17 [core.py:330]   File "/root/vllm-official/vllm/model_executor/models/qwen2.py", line 177, in forward
ERROR 03-21 00:10:17 [core.py:330]     attn_output = self.attn(q, k, v)
ERROR 03-21 00:10:17 [core.py:330]   File "/root/vllm-official/vllm/attention/layer.py", line 226, in forward
ERROR 03-21 00:10:17 [core.py:330]     return self.impl.forward(self, query, key, value,
ERROR 03-21 00:10:17 [core.py:330]   File "/root/vllm-official/vllm/v1/attention/backends/pallas.py", line 166, in forward
ERROR 03-21 00:10:17 [core.py:330]     output = torch.ops.xla.ragged_paged_attention(
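
Judging from the unpack in validate_ragged_paged_attention_inputs, the torch_xla kernel expects k_pages (and presumably v_pages) to be 4-D, so the caller in vllm/v1/attention/backends/pallas.py and the kernel currently disagree on the KV cache page layout. A hedged sketch of the shape the validator's unpack implies, again with hypothetical dimension values:

import torch

# 4-D layout implied by the validator: (num_pages, page_size, num_kv_heads, head_dim).
# Dimension values below are invented for illustration.
num_pages, page_size, num_kv_heads, head_dim = 16, 128, 8, 64
k_pages = torch.zeros(num_pages, page_size, num_kv_heads, head_dim)

# Same unpack as custom_kernel.py line 906; with a 4-D tensor it succeeds.
_, _, kv_heads, head_dim_k = k_pages.shape
assert (kv_heads, head_dim_k) == (num_kv_heads, head_dim)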

