Commit aedf9c4

Author: Jingchun Gao (committed)
[Lint] clean code
Signed-off-by: Jingchun Gao <[email protected]>
1 parent bd65197 commit aedf9c4

4 files changed: +39 -36 lines changed


tests/compile/test_fusions_e2e.py

Lines changed: 1 addition & 1 deletion
@@ -74,7 +74,7 @@ class ModelBackendTestCase(NamedTuple):
     ModelBackendTestCase(
         model_name="Qwen/Qwen3-30B-A3B",
         model_kwargs=dict(max_model_len=1024),
-        backend=_Backend.TRITON_ATTN,
+        backend=AttentionBackendEnum.TRITON_ATTN,
         attention_fusions=0,
         allreduce_fusions=97,
     ),
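
For context: the one-line change above swaps the private _Backend enum for the public AttentionBackendEnum in the test's parameter table. A minimal sketch of the pattern, with a hypothetical stand-in enum (the real AttentionBackendEnum lives in vLLM and has many more members):

    import enum
    from typing import Any, NamedTuple

    # Hypothetical stand-in for vllm's AttentionBackendEnum; only the
    # member used in this hunk is modeled.
    class AttentionBackendEnum(enum.Enum):
        TRITON_ATTN = "TRITON_ATTN"

    class ModelBackendTestCase(NamedTuple):
        model_name: str
        model_kwargs: dict[str, Any]
        backend: AttentionBackendEnum
        attention_fusions: int
        allreduce_fusions: int

    # One row of the parameter table, as it reads after this commit.
    CASE = ModelBackendTestCase(
        model_name="Qwen/Qwen3-30B-A3B",
        model_kwargs=dict(max_model_len=1024),
        backend=AttentionBackendEnum.TRITON_ATTN,
        attention_fusions=0,
        allreduce_fusions=97,
    )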

vllm/utils/flashinfer.py

Lines changed: 1 addition & 2 deletions
@@ -265,8 +265,7 @@ def use_trtllm_attention(
     # Decode context parallel is not supported
     if dcp_world_size > 1:
         logger.warning_once(
-            "Trtllm not support lse, please use flash attention "
-            "or FlashInfer backend."
+            "Trtllm not support lse, please use flash attention or FlashInfer backend."
         )
         return False
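
The string fold above is purely cosmetic; logger.warning_once is vLLM's deduplicating logging helper (the logger comes from vllm.logger.init_logger), so the message is emitted once per process either way. A minimal sketch of the warn-once idea, assuming only the standard library:

    import logging
    from functools import lru_cache

    logger = logging.getLogger(__name__)

    @lru_cache(maxsize=None)
    def warning_once(msg: str) -> None:
        # lru_cache swallows repeat calls with an identical message, so
        # each distinct warning is logged exactly once.
        logger.warning(msg)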

vllm/v1/attention/backends/flashinfer.py

Lines changed: 36 additions & 32 deletions
@@ -26,8 +26,8 @@
 from vllm.attention.ops.common import cp_lse_ag_out_rs
 from vllm.attention.ops.merge_attn_states import merge_attn_states
 from vllm.config import CUDAGraphMode, VllmConfig
-from vllm.distributed.parallel_state import get_dcp_group
 from vllm.config.cache import CacheDType
+from vllm.distributed.parallel_state import get_dcp_group
 from vllm.logger import init_logger
 from vllm.model_executor.layers.batch_invariant import (
     vllm_is_batch_invariant,
@@ -51,8 +51,8 @@
     AttentionCGSupport,
     AttentionMetadataBuilder,
     CommonAttentionMetadata,
-    get_dcp_local_seq_lens,
     KVCacheLayoutType,
+    get_dcp_local_seq_lens,
     get_kv_cache_layout,
     get_per_layer_parameters,
     infer_global_hyperparameters,
@@ -163,9 +163,11 @@ def trtllm_prefill_attn_kvfp8_dequant(
     )
     return mock_kv_cache, mock_block_table
 
+
 @dataclass
 class BatchDCPPrefillPlanConfig:
     """Parameters for BatchDCPPrefillWrapper.plan() method."""
+
     qo_indptr_cpu: torch.Tensor
     paged_kv_indptr_cpu: torch.Tensor
     paged_kv_indices: torch.Tensor
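
BatchDCPPrefillPlanConfig is a parameter object: it bundles plan()'s long argument list into one typed value (the blank line added after the docstring follows the pydocstyle convention of one blank line after a class docstring). A minimal sketch of the same pattern, with abbreviated hypothetical fields rather than vLLM's real ones:

    from dataclasses import dataclass

    import torch

    @dataclass
    class PlanConfig:
        """Parameter object grouping related plan() arguments."""

        qo_indptr: torch.Tensor
        kv_indptr: torch.Tensor
        num_qo_heads: int
        num_kv_heads: int

    class Wrapper:
        def plan(self, cfg: PlanConfig) -> None:
            # The callee unpacks one config instead of taking a dozen
            # positional arguments, keeping call sites readable.
            assert cfg.qo_indptr.numel() == cfg.kv_indptr.numel()

    Wrapper().plan(
        PlanConfig(
            qo_indptr=torch.zeros(4, dtype=torch.int32),
            kv_indptr=torch.zeros(4, dtype=torch.int32),
            num_qo_heads=32,
            num_kv_heads=8,
        )
    )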
@@ -203,7 +205,7 @@ def plan(self, cfg: BatchDCPPrefillPlanConfig):
             cfg.qo_indptr_cpu,
             cfg.paged_kv_indptr_cpu,
             cfg.paged_kv_indices,
-            cfg.paged_kv_last_page_len_cpu[cfg.prefill_start:],
+            cfg.paged_kv_last_page_len_cpu[cfg.prefill_start :],
             cfg.num_qo_heads * cfg.dcp_world_size,
             cfg.num_kv_heads,
             cfg.head_dim,
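
The added space in [cfg.prefill_start :] is the Black/ruff-format slice rule: when a slice operand is anything more complex than a plain name or literal (attribute access counts), the colon is padded on both sides. For example:

    items = list(range(10))
    start = 3

    a = items[start:]       # simple name: no padding
    b = items[start + 1 :]  # complex operand: the formatter pads the colon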
@@ -230,15 +232,15 @@ def plan(self, cfg: BatchDCPPrefillPlanConfig):
             logits_soft_cap=cfg.logits_soft_cap,
             q_data_type=cfg.q_data_type,
         )
-
+
     def run(
-            self,
-            layer: torch.nn.Module,
-            prefill_query: torch.Tensor,
-            kv_cache_permute: torch.Tensor,
-            key: torch.Tensor,
-            value: torch.Tensor,
-    ):
+        self,
+        layer: torch.nn.Module,
+        prefill_query: torch.Tensor,
+        kv_cache_permute: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+    ):
         prefill_query_across_dcp = get_dcp_group().all_gather(
             prefill_query.contiguous(), dim=1
         )
@@ -272,6 +274,7 @@ def run(
         )
         return output
 
+
 class FlashInferBackend(AttentionBackend):
     accept_output_buffer: bool = True
     supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
@@ -379,9 +382,7 @@ class FlashInferMetadata:
     use_cascade: bool
 
     prefill_wrapper: (
-        BatchPrefillWithPagedKVCacheWrapper
-        | BatchDCPPrefillWrapper
-        | None
+        BatchPrefillWithPagedKVCacheWrapper | BatchDCPPrefillWrapper | None
     ) = None
     decode_wrapper: BatchDecodeWithPagedKVCacheWrapper | None = None
     cascade_wrapper: MultiLevelCascadeAttentionWrapper | None = None
@@ -409,13 +410,7 @@ def __init__(
         self.model_config = vllm_config.model_config
         self._workspace_buffer = None
         self._prefill_wrapper: (
-            BatchPrefillWithPagedKVCacheWrapper
-            | dict[
-                str,
-                BatchPrefillWithPagedKVCacheWrapper
-                | BatchPrefillWithRaggedKVCacheWrapper,
-            ]
-            | None
+            BatchPrefillWithPagedKVCacheWrapper | BatchDCPPrefillWrapper | None
         ) = None  # Wrapper for prefill/append
         self._decode_wrapper = None  # Wrapper for decode (general shape)
@@ -563,13 +558,7 @@ def _get_workspace_buffer(self):
 
     def _get_prefill_wrapper(
         self,
-    ) -> (
-        BatchPrefillWithPagedKVCacheWrapper
-        | dict[
-            str,
-            BatchPrefillWithPagedKVCacheWrapper | BatchPrefillWithRaggedKVCacheWrapper,
-        ]
-    ):
+    ) -> BatchPrefillWithPagedKVCacheWrapper | BatchDCPPrefillWrapper:
         if self._prefill_wrapper is None:
             if self.dcp_world_size > 1:
                 self._prefill_wrapper = BatchDCPPrefillWrapper(
@@ -855,7 +844,9 @@ def build(
 
         if not attn_metadata.prefill_use_trtllm:
             if self.dcp_world_size > 1:
-                assert isinstance(attn_metadata.prefill_wrapper, BatchDCPPrefillWrapper)
+                assert isinstance(
+                    attn_metadata.prefill_wrapper, BatchDCPPrefillWrapper
+                )
                 plan_cfgs = BatchDCPPrefillPlanConfig(
                     qo_indptr_cpu=qo_indptr_cpu,
                     paged_kv_indptr_cpu=paged_kv_indptr_cpu,
@@ -877,6 +868,10 @@ def build(
                 )
                 attn_metadata.prefill_wrapper.plan(plan_cfgs)
             else:
+                assert isinstance(
+                    attn_metadata.prefill_wrapper,
+                    BatchPrefillWithPagedKVCacheWrapper,
+                )
                 attn_metadata.prefill_wrapper.plan(
                     qo_indptr_cpu,
                     paged_kv_indptr_cpu,
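
These isinstance asserts are not only runtime guards: they narrow the union-typed prefill_wrapper (BatchPrefillWithPagedKVCacheWrapper | BatchDCPPrefillWrapper | None, per the annotations above) so the type checker accepts each branch's plan() signature. A minimal sketch of the narrowing, with hypothetical wrapper classes:

    class PagedWrapper:
        def plan(self, qo_indptr: list[int], kv_indptr: list[int]) -> None: ...

    class DCPWrapper:
        def plan(self, cfg: object) -> None: ...

    def build(wrapper: PagedWrapper | DCPWrapper | None, use_dcp: bool) -> None:
        if use_dcp:
            # Narrows the union to DCPWrapper for mypy/pyright, and fails
            # fast at runtime if the wrapper was wired up wrong.
            assert isinstance(wrapper, DCPWrapper)
            wrapper.plan(cfg=object())
        else:
            assert isinstance(wrapper, PagedWrapper)
            wrapper.plan(qo_indptr=[0], kv_indptr=[0])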
@@ -1196,11 +1191,15 @@ def forward(
             if self.dcp_world_size > 1:
                 assert isinstance(prefill_wrapper, BatchDCPPrefillWrapper)
                 assert prefill_wrapper._context._window_left == self.window_left
-                assert prefill_wrapper._context._logits_soft_cap == (self.logits_soft_cap or 0.0)
+                assert prefill_wrapper._context._logits_soft_cap == (
+                    self.logits_soft_cap or 0.0
+                )
                 assert prefill_wrapper._context._sm_scale == self.scale
                 assert not prefill_wrapper._context.causal
                 assert prefill_wrapper._new_tokens._window_left == self.window_left
-                assert prefill_wrapper._new_tokens._logits_soft_cap == (self.logits_soft_cap or 0.0)
+                assert prefill_wrapper._new_tokens._logits_soft_cap == (
+                    self.logits_soft_cap or 0.0
+                )
                 assert prefill_wrapper._new_tokens._sm_scale == self.scale
                 assert prefill_wrapper._new_tokens.causal
 
@@ -1212,8 +1211,13 @@ def forward(
                     value[num_decode_tokens:],
                 )
             else:
+                assert isinstance(
+                    prefill_wrapper, BatchPrefillWithPagedKVCacheWrapper
+                )
                 assert prefill_wrapper._window_left == self.window_left
-                assert prefill_wrapper._logits_soft_cap == (self.logits_soft_cap or 0.0)
+                assert prefill_wrapper._logits_soft_cap == (
+                    self.logits_soft_cap or 0.0
+                )
                 assert prefill_wrapper._sm_scale == self.scale
                 assert prefill_wrapper._causal
                 prefill_wrapper.run(

vllm/v1/executor/multiproc_executor.py

Lines changed: 1 addition & 1 deletion
@@ -31,11 +31,11 @@
 from vllm.distributed.device_communicators.shm_broadcast import Handle, MessageQueue
 from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
 from vllm.distributed.parallel_state import (
+    get_dcp_group,
     get_dp_group,
     get_ep_group,
     get_pp_group,
     get_tp_group,
-    get_dcp_group,
 )
 from vllm.envs import enable_envs_cache
 from vllm.logger import init_logger
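
The only change here is ordering: isort-style sorting (ruff's I001) keeps the names inside a parenthesized from-import alphabetical, so get_dcp_group moves from last to first. A small self-check of that invariant, using only the standard library:

    import ast
    import textwrap

    SRC = textwrap.dedent(
        """
        from vllm.distributed.parallel_state import (
            get_dcp_group,
            get_dp_group,
            get_ep_group,
            get_pp_group,
            get_tp_group,
        )
        """
    )

    # Assert the imported names are alphabetically sorted, which is the
    # property the lint rule enforces for this block.
    (node,) = ast.parse(SRC).body
    assert isinstance(node, ast.ImportFrom)
    names = [alias.name for alias in node.names]
    assert names == sorted(names), names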
