Commit 5c03727

cleanup
Signed-off-by: Lucas Wilkinson <[email protected]>
1 parent f6e97c0

File tree

3 files changed (+56, -80 lines):
  vllm/v1/attention/backends/flashinfer.py
  vllm/v1/attention/backends/mamba_attn.py
  vllm/v1/attention/backends/rocm_aiter_fa.py

vllm/v1/attention/backends/flashinfer.py
Lines changed: 15 additions & 21 deletions

@@ -14,7 +14,8 @@
 import vllm.envs as envs
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionType)
-from vllm.config import VllmConfig
+from vllm.attention.layer import Attention
+from vllm.config import VllmConfig, get_layers_from_vllm_config
 from vllm.logger import init_logger
 from vllm.v1.attention.backends.flash_attn import use_cascade_attention
 from vllm.v1.attention.backends.utils import (
@@ -23,7 +24,6 @@
 from vllm.v1.kv_cache_interface import AttentionSpec
 
 if TYPE_CHECKING:
-    from vllm.config import VllmConfig
     from vllm.v1.core.sched.output import SchedulerOutput
     from vllm.v1.worker.gpu_input_batch import InputBatch
 
@@ -109,27 +109,21 @@ def get_per_layer_parameters(
     Scan all attention layers and determine some hyperparameters
     to use during `plan`.
     """
-    model_config = vllm_config.model_config
+    layers = get_layers_from_vllm_config(vllm_config, Attention)
+    per_layer_params: dict[str, PerLayerParameters] = {}
 
-    # This is a workaround for the fact that the attention backend
-    # in this standalone test does not have access to the full model,
-    # so we mock the layer access.
-    if not hasattr(model_config, 'get_num_layers'):
-        raise RuntimeError(
-            "The model config does not have a get_num_layers method. "
-            "This is required for the FlashInfer backend.")
+    for key, layer in layers.items():
+        impl = layer.impl
+        assert isinstance(impl, FlashInferImpl)
 
-    per_layer_params: dict[str, PerLayerParameters] = {}
-    for i in range(model_config.get_num_layers()):
-        sliding_window = model_config.get_sliding_window_for_layer(i)
-        logits_soft_cap = model_config.get_logits_soft_cap_for_layer(i)
-        sm_scale = model_config.get_sm_scale_for_layer(i)
-
-        per_layer_params[str(i)] = PerLayerParameters(
-            sm_scale=sm_scale,
-            logits_soft_cap=logits_soft_cap,
-            window_left=sliding_window if sliding_window is not None else -1,
-        )
+        # Infer hyperparameters from the attention layer
+        window_size = impl.sliding_window
+        window_left = window_size[0] if window_size is not None else -1
+        logits_soft_cap = impl.logits_soft_cap
+        sm_scale = impl.scale
+
+        per_layer_params[key] = PerLayerParameters(window_left,
+                                                   logits_soft_cap, sm_scale)
 
     return per_layer_params

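Note: the rewritten get_per_layer_parameters above derives FlashInfer's plan() hyperparameters from the instantiated Attention layers (via get_layers_from_vllm_config) instead of the model config, and keys the result by layer name rather than layer index. A minimal, self-contained sketch of that collection pattern follows; ToyImpl, ToyLayer and collect_per_layer_params are hypothetical stand-ins, not vLLM's actual Attention/FlashInferImpl types, and only the field names mirror the diff.

# Illustrative sketch only; toy stand-in classes, not vLLM code.
from dataclasses import dataclass
from typing import Optional


@dataclass
class PerLayerParameters:
    window_left: int
    logits_soft_cap: Optional[float]
    sm_scale: float


@dataclass
class ToyImpl:  # stands in for FlashInferImpl
    sliding_window: Optional[tuple[int, int]]
    logits_soft_cap: Optional[float]
    scale: float


@dataclass
class ToyLayer:  # stands in for vllm.attention.layer.Attention
    impl: ToyImpl


def collect_per_layer_params(
        layers: dict[str, ToyLayer]) -> dict[str, PerLayerParameters]:
    """Mirrors the loop in get_per_layer_parameters: one entry per layer name."""
    per_layer_params: dict[str, PerLayerParameters] = {}
    for key, layer in layers.items():
        impl = layer.impl
        window_size = impl.sliding_window
        window_left = window_size[0] if window_size is not None else -1
        per_layer_params[key] = PerLayerParameters(window_left,
                                                   impl.logits_soft_cap,
                                                   impl.scale)
    return per_layer_params


if __name__ == "__main__":
    layers = {
        "model.layers.0.self_attn.attn": ToyLayer(ToyImpl(None, None, 0.125)),
        "model.layers.1.self_attn.attn": ToyLayer(ToyImpl((4096, 0), 30.0, 0.125)),
    }
    print(collect_per_layer_params(layers))

Keying by the layer name returned from layers.items(), as the new code does, lets callers look up parameters under the same names the layer registry uses instead of stringified indices.
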
vllm/v1/attention/backends/mamba_attn.py
Lines changed: 6 additions & 11 deletions

@@ -12,12 +12,10 @@
 from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
                                               CommonAttentionMetadata)
 from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
-from vllm.v1.worker.block_table import BlockTable
 
 if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput
     from vllm.v1.worker.gpu_input_batch import InputBatch
-    from vllm.v1.worker.gpu_model_runner import GPUModelRunner
 
 
 def get_mamba2_chunk_size(vllm_config: VllmConfig) -> int:
@@ -58,13 +56,11 @@ class Mamba2AttentionMetadata:
 class Mamba2AttentionMetadataBuilder(
         AttentionMetadataBuilder[Mamba2AttentionMetadata]):
 
-    def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec,
-                 block_table: BlockTable):
+    def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
+                 device: torch.device):
         assert isinstance(kv_cache_spec, MambaSpec)
-        self.runner = runner
         self.kv_cache_spec = kv_cache_spec
-        self.block_table = block_table
-        self.chunk_size = get_mamba2_chunk_size(runner.vllm_config)
+        self.chunk_size = get_mamba2_chunk_size(vllm_config)
 
     def reorder_batch(self, input_batch: "InputBatch",
                       scheduler_output: "SchedulerOutput") -> bool:
@@ -143,15 +139,14 @@ def build(self,
         has_initial_states = None
         prep_initial_states = False
 
-        state_indices_tensor = self.block_table.block_table[:num_reqs, 0]
+        state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
 
         # Compute seq_idx, chunk_indices and chunk_offsets for prefill only
         if self._num_prefills > 0:
             #[batch,]
             has_initial_states_cpu = (
-                self.runner.input_batch.
-                num_computed_tokens_cpu_tensor[num_reqs -
-                                               self._num_prefills:num_reqs]
+                common_attn_metadata.
+                num_computed_tokens_cpu[num_reqs - self._num_prefills:num_reqs]
                 > 0)
             prep_initial_states = torch.any(has_initial_states_cpu).item()
             has_initial_states = has_initial_states_cpu.to(

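Note: the rewritten build() above pulls per-request state directly from CommonAttentionMetadata: column 0 of block_table_tensor gives the Mamba state indices, and num_computed_tokens_cpu > 0 over the trailing prefill slice marks requests that already have cached state. Below is a small self-contained sketch of that slicing, with plain tensors standing in for the metadata fields; the values and the prefills-last ordering are illustrative assumptions.

# Illustrative sketch; toy tensors, not CommonAttentionMetadata.
import torch

num_reqs = 4
num_prefills = 2  # assumes prefill requests are ordered last in the batch

# Stand-ins for block_table_tensor / num_computed_tokens_cpu
block_table_tensor = torch.arange(num_reqs * 8).reshape(num_reqs, 8)
num_computed_tokens_cpu = torch.tensor([7, 0, 16, 0])

# First block-table column holds one state slot index per request
state_indices_tensor = block_table_tensor[:, 0]

# Only the trailing prefill requests are checked for pre-existing state
has_initial_states_cpu = (
    num_computed_tokens_cpu[num_reqs - num_prefills:num_reqs] > 0)
prep_initial_states = bool(torch.any(has_initial_states_cpu).item())

print(state_indices_tensor)    # tensor([ 0,  8, 16, 24])
print(has_initial_states_cpu)  # tensor([ True, False])
print(prep_initial_states)     # True

With the metadata carried in CommonAttentionMetadata, the builder no longer needs a reference to the runner or its BlockTable, which is what allows the constructor to shrink to (kv_cache_spec, vllm_config, device).
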
vllm/v1/attention/backends/rocm_aiter_fa.py
Lines changed: 35 additions & 48 deletions

@@ -2,26 +2,21 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """Attention layer with AiterFlashAttention."""
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Optional
+from typing import Any, Optional
 
 import torch
 
 from vllm import _custom_ops as ops
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata, AttentionType,
                                               is_quantized_kv_cache)
+from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.v1.attention.backends.flash_attn import (
     make_local_attention_virtual_batches)
 from vllm.v1.attention.backends.utils import CommonAttentionMetadata
 from vllm.v1.kv_cache_interface import AttentionSpec
-from vllm.v1.worker.block_table import BlockTable
-
-if TYPE_CHECKING:
-    from vllm.v1.core.sched.output import SchedulerOutput
-    from vllm.v1.worker.gpu_input_batch import InputBatch
-    from vllm.v1.worker.gpu_model_runner import GPUModelRunner
 
 if current_platform.is_rocm():
     import aiter
@@ -172,56 +167,49 @@ def flash_attn_varlen_func_fake(
 
 class AiterFlashAttentionMetadataBuilder:
 
-    def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec,
-                 block_table: BlockTable):
-        model_config = runner.model_config
-
-        self.runner = runner
-        self.num_heads_q = model_config.get_num_attention_heads(
-            runner.parallel_config)
-        self.num_heads_kv = model_config.get_num_kv_heads(
-            runner.parallel_config)
-        self.headdim = model_config.get_head_size()
+    def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
+                 device: torch.device):
+        self.vllm_config = vllm_config
+        self.model_config = vllm_config.model_config
+        self.parallel_config = vllm_config.parallel_config
+        self.cache_config = vllm_config.cache_config
+        self.device = device
+
+        self.num_heads_q = self.model_config.get_num_attention_heads(
+            self.parallel_config)
+        self.num_heads_kv = self.model_config.get_num_kv_heads(
+            self.parallel_config)
+        self.headdim = self.model_config.get_head_size()
         self.block_size = kv_cache_spec.block_size
         self.kv_cache_spec = kv_cache_spec
-        self.block_table = block_table
 
         # Sliding window size to be used with the AOT scheduler will be
        # populated on first build() call.
         self.aot_sliding_window: Optional[tuple[int, int]] = None
 
-    def reorder_batch(self, input_batch: "InputBatch",
-                      scheduler_output: "SchedulerOutput") -> bool:
+    def reorder_batch(self, input_batch, scheduler_output) -> bool:
         return False
 
     def build(self,
               common_prefix_len: int,
               common_attn_metadata: CommonAttentionMetadata,
               fast_build: bool = False) -> 'AiterFlashAttentionMetadata':
 
-        num_reqs = common_attn_metadata.num_reqs
         num_actual_tokens = common_attn_metadata.num_actual_tokens
         max_query_len = common_attn_metadata.max_query_len
 
-        max_seq_len = int(self.runner.seq_lens_np[:num_reqs].max())
-        total_tokens = int(self.runner.seq_lens_np[:num_reqs].sum())
+        max_seq_len = int(common_attn_metadata.seq_lens_cpu.max().item())
+        total_tokens = int(common_attn_metadata.seq_lens_cpu.sum().item())
         query_start_loc = common_attn_metadata.query_start_loc
+        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
         seq_lens = common_attn_metadata.seq_lens
-        block_table = self.block_table
-        block_table_tensor = block_table.get_device_tensor()[:num_reqs]
-
-        block_table.slot_mapping[:num_actual_tokens].copy_(
-            block_table.slot_mapping_cpu[:num_actual_tokens],
-            non_blocking=True)
-        # Fill unused with -1. Needed for reshape_and_cache in full cuda graph
-        # mode.
-        block_table.slot_mapping[num_actual_tokens:].fill_(-1)
-
-        slot_mapping = block_table.slot_mapping[:num_actual_tokens]
+        seq_lens_cpu = common_attn_metadata.seq_lens_cpu
+        block_table_tensor = common_attn_metadata.block_table_tensor
+        slot_mapping = common_attn_metadata.slot_mapping
 
         cu_seq_lens = torch.zeros(seq_lens.shape[0] + 1,
                                   dtype=torch.int32,
-                                  device="cuda")
+                                  device=self.device)
         torch.cumsum(seq_lens,
                      dim=0,
                      dtype=cu_seq_lens.dtype,
@@ -233,21 +221,21 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
 
         # for local attention
         local_attn_metadata = None
-        if self.runner.attention_chunk_size is not None:
+        if self.model_config.attention_chunk_size is not None:
             seqlens_q_local_np, virt_q_cu_seqlens_np, virt_k_seqlens_np, \
                 virt_block_table_tensor = make_local_attention_virtual_batches(
-                    self.runner.attention_chunk_size,
-                    self.runner.query_start_loc_np[:num_reqs + 1],
-                    self.runner.seq_lens_np[:num_reqs],
+                    self.model_config.attention_chunk_size,
+                    query_start_loc_cpu.numpy(),
+                    seq_lens_cpu.numpy(),
                     block_table_tensor,
                     self.block_size,
                 )
             local_query_start_loc = torch.from_numpy(virt_q_cu_seqlens_np).to(
-                self.runner.device, non_blocking=True)
+                self.device, non_blocking=True)
             local_seqused_k = torch.from_numpy(virt_k_seqlens_np).to(
-                self.runner.device, non_blocking=True)
-            local_max_query_len = int(seqlens_q_local_np.max())
-            local_max_seq_len = int(virt_k_seqlens_np.max())
+                self.device, non_blocking=True)
+            local_max_query_len = seqlens_q_local_np.max().item()
+            local_max_seq_len = virt_k_seqlens_np.max().item()
             local_scheduler_metadata = schedule(
                 batch_size=local_query_start_loc.shape[0] - 1,
                 cu_query_lens=local_query_start_loc,
@@ -258,12 +246,11 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
 
             local_cu_seq_lens = torch.zeros(virt_k_seqlens_np.shape[0] + 1,
                                             dtype=torch.int32,
-                                            device=self.runner.device)
+                                            device=self.device)
             local_cu_seq_lens[1:] = torch.cumsum(
-                torch.from_numpy(virt_k_seqlens_np).to(
-                    device=self.runner.device,
-                    dtype=torch.int32,
-                    non_blocking=True),
+                torch.from_numpy(virt_k_seqlens_np).to(device=self.device,
                                                        dtype=torch.int32,
                                                        non_blocking=True),
                 dim=0)
 
 
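Note: both the main and the local-attention paths above now build cumulative sequence lengths on self.device rather than a hard-coded "cuda", and take max_seq_len / total_tokens from the CPU copy of the sequence lengths instead of runner state. Below is a short self-contained sketch of that construction; seq_lens and seq_lens_cpu are toy tensors standing in for the CommonAttentionMetadata fields, and the device simply falls back to CPU when no GPU is available.

# Illustrative sketch; toy tensors, not CommonAttentionMetadata.
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
seq_lens_cpu = torch.tensor([5, 3, 9, 1], dtype=torch.int32)
seq_lens = seq_lens_cpu.to(device)

# CPU-side reductions, mirroring max_seq_len / total_tokens in build()
max_seq_len = int(seq_lens_cpu.max().item())
total_tokens = int(seq_lens_cpu.sum().item())

# Cumulative sequence lengths with a leading zero, on the builder's device
cu_seq_lens = torch.zeros(seq_lens.shape[0] + 1,
                          dtype=torch.int32,
                          device=device)
cu_seq_lens[1:] = torch.cumsum(seq_lens, dim=0, dtype=cu_seq_lens.dtype)

print(max_seq_len, total_tokens)  # 9 18
print(cu_seq_lens)                # tensor([ 0,  5,  8, 17, 18], ...)

Taking the block table and slot mapping straight from common_attn_metadata removes the builder's dependence on the runner's BlockTable bookkeeping, matching the constructor change to (kv_cache_spec, vllm_config, device).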