
Commit f1f41ad

zixi-qi authored and houseroad committed
[bug fix] Fix llama4 spec decoding (vllm-project#22691)
Signed-off-by: qizixi <[email protected]>
Co-authored-by: Lu Fang <[email protected]>
1 parent 9360372 commit f1f41ad

1 file changed: +4, -2 lines

vllm/model_executor/models/llama4.py

Lines changed: 4 additions & 2 deletions
@@ -195,7 +195,9 @@ def __init__(self,
             is_neox_style=is_neox_style,
         ) if not self.nope else None
 
-        attn_cls = Attention if self.nope else ChunkedLocalAttention
+        use_chunked_local_attn = not self.nope and config.attention_chunk_size
+        attn_cls = (ChunkedLocalAttention
+                    if use_chunked_local_attn else Attention)
         self.attn = attn_cls(
             self.num_heads,
             self.head_dim,
@@ -206,7 +208,7 @@ def __init__(self,
             prefix=f"{prefix}.attn",
             **({
                 "attention_chunk_size": config.attention_chunk_size
-            } if not self.nope else {}))
+            } if use_chunked_local_attn else {}))
 
     def _get_attn_scale(self, positions: torch.Tensor) -> torch.Tensor:
         floor = torch.floor((positions + 1.0) / self.floor_scale)
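In short, the old code selected ChunkedLocalAttention for every RoPE layer (not self.nope), even when config.attention_chunk_size was not set, which appears to be the configuration hit during Llama 4 speculative decoding per the commit title; after this change, chunked local attention is only used when a chunk size is actually configured. Below is a minimal, self-contained sketch of that selection logic; the stub classes and the select_attn_cls helper are illustrative stand-ins, not vLLM's real Attention/ChunkedLocalAttention APIs.

# Minimal sketch of the selection logic after this commit.
# Attention / ChunkedLocalAttention are illustrative stubs, not vLLM's
# real classes; attention_chunk_size mirrors config.attention_chunk_size,
# which may be None for some configs (e.g. a speculative-decoding draft).
from typing import Optional


class Attention:
    """Stub standing in for the global-attention implementation."""


class ChunkedLocalAttention:
    """Stub standing in for the chunked local-attention implementation."""


def select_attn_cls(nope: bool, attention_chunk_size: Optional[int]) -> type:
    # Old behaviour was `Attention if nope else ChunkedLocalAttention`,
    # which picked chunked attention even when no chunk size was set.
    use_chunked_local_attn = not nope and bool(attention_chunk_size)
    return ChunkedLocalAttention if use_chunked_local_attn else Attention


# RoPE layer with a chunk size configured -> chunked local attention.
assert select_attn_cls(nope=False, attention_chunk_size=8192) is ChunkedLocalAttention
# RoPE layer but no chunk size (the case this commit fixes) -> plain attention.
assert select_attn_cls(nope=False, attention_chunk_size=None) is Attention
# NoPE layer -> always plain attention.
assert select_attn_cls(nope=True, attention_chunk_size=8192) is Attention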
