
Commit b087694

format

Signed-off-by: Lucas Wilkinson <[email protected]>
1 parent 701fdc0

2 files changed: 8 additions & 8 deletions

tests/v1/attention/test_attention_backends.py

Lines changed: 6 additions & 5 deletions
@@ -9,10 +9,10 @@
     create_common_attn_metadata,
     create_standard_kv_cache_spec,
     create_vllm_config,
-    get_attention_backend,)
+    get_attention_backend)
 from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv
-from vllm.v1.attention.backends.utils import (
-    CommonAttentionMetadata, set_kv_cache_layout)
+from vllm.v1.attention.backends.utils import (CommonAttentionMetadata,
+                                              set_kv_cache_layout)
 from vllm.v1.kv_cache_interface import FullAttentionSpec
 
 BACKENDS_TO_TEST = [
@@ -421,8 +421,9 @@ def test_backend_correctness(batch_spec_name: str, model: str):
         if backend_name == _Backend.FLASHINFER_VLLM_V1:
             kv_cache_for_backend = kv_cache.transpose(0, 1)
 
-            # For FlashInfer default to HND layout and
-            kv_cache_for_backend = kv_cache_for_backend.transpose(2, 3).contiguous().transpose(2, 3)
+            # For FlashInfer default to HND layout and
+            kv_cache_for_backend = kv_cache_for_backend.transpose(
+                2, 3).contiguous().transpose(2, 3)
             set_kv_cache_layout("HND")
 
         backend_output = run_attention_backend(backend_name, kv_cache_spec,
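
Aside, not part of the commit: the transpose(2, 3).contiguous().transpose(2, 3) pattern re-wrapped above is easy to misread. A minimal sketch of what it does, with illustrative dimensions (assumptions, not vLLM's real configuration): it rewrites the tensor's backing memory into HND (head-major) order while leaving the logical shape and values untouched, which is why the test then calls set_kv_cache_layout("HND").

import torch

# Illustrative sizes only -- the real values come from vLLM's config.
num_blocks, block_size, num_heads, head_dim = 4, 16, 8, 64

# Assumed shape after the kv_cache.transpose(0, 1) above:
# (num_blocks, 2, block_size, num_heads, head_dim), i.e. NHD over dims 2-4.
kv_cache = torch.randn(num_blocks, 2, block_size, num_heads, head_dim)

hnd_backed = kv_cache.transpose(2, 3).contiguous().transpose(2, 3)

assert hnd_backed.shape == kv_cache.shape  # logical shape unchanged
assert torch.equal(hnd_backed, kv_cache)   # values unchanged
# The memory now walks heads before tokens (HND): the heads-axis stride
# exceeds the tokens-axis stride, the opposite of a contiguous NHD cache.
assert hnd_backed.stride(3) > hnd_backed.stride(2)
assert kv_cache.stride(3) < kv_cache.stride(2)
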

vllm/v1/attention/backends/flashinfer.py

Lines changed: 2 additions & 3 deletions
@@ -7,12 +7,12 @@
 from typing import TYPE_CHECKING, Optional
 
 import torch
+
+import vllm.envs as envs
 from flashinfer import (BatchDecodeWithPagedKVCacheWrapper,
                         BatchPrefillWithPagedKVCacheWrapper,
                         MultiLevelCascadeAttentionWrapper)
 from flashinfer.decode import trtllm_batch_decode_with_kv_cache
-
-import vllm.envs as envs
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionType)
 from vllm.config import VllmConfig
@@ -673,7 +673,6 @@ def forward(
         assert block_tables_decode.is_contiguous()
         assert seq_lens_decode.is_contiguous()
 
-
         output[:num_decode_tokens] = (
             trtllm_batch_decode_with_kv_cache(
                 query=decode_query,
