Closed
2 changes: 1 addition & 1 deletion vllm/v1/attention/backends/rocm_attn.py
@@ -121,6 +121,7 @@ def build(
slot_mapping = common_attn_metadata.slot_mapping

use_cascade = common_prefix_len > 0
+ prefix_scheduler_metadata = None
Collaborator

Why do we need the variable if it is always None?
The same pattern exists in triton_attn.py as well.

Contributor Author

The variable needs to exist because the constructor at line 153 explicitly passes prefix_scheduler_metadata=prefix_scheduler_metadata. Before this change it was assigned only in the else branch, so when use_cascade=True the if branch runs, the name is never bound, and Python raises UnboundLocalError when the constructor call reads it.
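To make the failure mode concrete, here is a minimal sketch of the control flow described above. It is not the vLLM code: `build_buggy` and `build_fixed` are hypothetical stand-ins, a plain dict replaces the metadata constructor, and a list replaces the real tensor.

```python
def build_buggy(use_cascade: bool) -> dict:
    """Pre-fix shape: the name is bound only in the else branch."""
    if use_cascade:
        cu_prefix_query_lens = [0, 4]  # placeholder for the real tensor
    else:
        cu_prefix_query_lens = None
        prefix_scheduler_metadata = None
    # When use_cascade is True, the local below was never bound,
    # so reading it raises UnboundLocalError.
    return {
        "cu_prefix_query_lens": cu_prefix_query_lens,
        "prefix_scheduler_metadata": prefix_scheduler_metadata,
    }


def build_fixed(use_cascade: bool) -> dict:
    """Post-fix shape: the name is bound before the conditional."""
    prefix_scheduler_metadata = None
    if use_cascade:
        cu_prefix_query_lens = [0, 4]  # placeholder for the real tensor
    else:
        cu_prefix_query_lens = None
    return {
        "cu_prefix_query_lens": cu_prefix_query_lens,
        "prefix_scheduler_metadata": prefix_scheduler_metadata,
    }
```

Calling `build_buggy(False)` succeeds, which is presumably why the bug stayed latent: it only surfaces on the cascade path.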

An alternative fix would be to remove the explicit kwarg from the constructor and let the dataclass default (= None) handle it, similar to how scheduler_metadata is already handled. However, pre-initializing before the conditional matches the pattern in flash_attn.py (line 427), which later assigns a real tensor via schedule() in the cascade path (line 465). This keeps the code forward-compatible for when the ROCm backend adopts AOT scheduling.
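For comparison, the alternative could look roughly like the sketch below. `ToyAttentionMetadata` and `build_without_kwarg` are hypothetical names, and the field list is an assumption reduced to what this thread mentions; the real RocmAttentionMetadata has more fields.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class ToyAttentionMetadata:
    """Hypothetical stand-in for the backend's metadata dataclass."""
    num_actual_tokens: int
    # Both fields default to None, so callers may omit them.
    scheduler_metadata: Optional[Any] = None
    prefix_scheduler_metadata: Optional[Any] = None


def build_without_kwarg(num_actual_tokens: int, use_cascade: bool) -> ToyAttentionMetadata:
    # No local variable is needed on either branch: the kwarg is simply
    # not passed, and the dataclass default supplies None.
    return ToyAttentionMetadata(num_actual_tokens=num_actual_tokens)
```

The trade-off is that a backend that later computes real prefix scheduling metadata would have to reintroduce both the local variable and the kwarg.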

Regarding triton_attn.py: it has the same latent bug (at line 231 the variable is initialized only in the else branch, and it is passed explicitly at line 246). Happy to include a fix for it in this PR or as a follow-up.


if use_cascade:
cu_prefix_query_lens = torch.tensor(
@@ -135,7 +136,6 @@ def build(
cu_prefix_query_lens = None
prefix_kv_lens = None
suffix_kv_lens = None
- prefix_scheduler_metadata = None

attn_metadata = RocmAttentionMetadata(
num_actual_tokens=num_actual_tokens,