from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.kv_cache_interface import FullAttentionSpec

+BACKENDS_TO_TEST = ["flash_attn", "flashinfer", "flex_attention"]
+
+# Remove flashinfer from the list if it's not available
+try:
+    import flashinfer  # noqa: F401
+except ImportError:
+    BACKENDS_TO_TEST.remove("flashinfer")
+

def _convert_dtype_to_torch(dtype):
    """Convert ModelDType to torch.dtype."""
@@ -84,6 +92,9 @@ def __init__(self):
        self._q_scale = torch.tensor(1.0)
        self._k_scale = torch.tensor(1.0)
        self._v_scale = torch.tensor(1.0)
+        # Add float versions for flashinfer
+        self._k_scale_float = 1.0
+        self._v_scale_float = 1.0


def run_attention_backend(backend_name: str, kv_cache_spec: FullAttentionSpec,
@@ -96,22 +107,52 @@ def run_attention_backend(backend_name: str, kv_cache_spec: FullAttentionSpec,

    builder_cls, impl_cls = get_attention_backend(backend_name)

-    # Build metadata
-    builder = builder_cls(kv_cache_spec, vllm_config, device)
-    attn_metadata = builder.build(
-        common_prefix_len=0,
-        common_attn_metadata=common_attn_metadata,
-    )
+    # Mock flashinfer's get_per_layer_parameters if needed
+    if backend_name == "flashinfer":
+        import unittest.mock
+
+        from vllm.v1.attention.backends.flashinfer import PerLayerParameters
+
+        def mock_get_per_layer_parameters(vllm_config):
+            # Return mock parameters for a single layer
+            head_size = vllm_config.model_config.get_head_size()
+            return {
+                "mock_layer":
+                PerLayerParameters(
+                    window_left=-1,  # No sliding window
+                    logits_soft_cap=0.0,  # No soft cap
+                    sm_scale=1.0 / (head_size**0.5)  # Standard scale
+                )
+            }
+
+        with unittest.mock.patch(
+                'vllm.v1.attention.backends.flashinfer.get_per_layer_parameters',
+                mock_get_per_layer_parameters):
+            builder = builder_cls(kv_cache_spec, vllm_config, device)
+            attn_metadata = builder.build(
+                common_prefix_len=0,
+                common_attn_metadata=common_attn_metadata,
+            )
+    else:
+        # Build metadata
+        builder = builder_cls(kv_cache_spec, vllm_config, device)
+        attn_metadata = builder.build(
+            common_prefix_len=0,
+            common_attn_metadata=common_attn_metadata,
+        )

    # Instantiate implementation
-    num_heads = kv_cache_spec.num_kv_heads
-    head_size = kv_cache_spec.head_size
+    num_heads = vllm_config.model_config.get_num_attention_heads(
+        vllm_config.parallel_config)
+    num_kv_heads = vllm_config.model_config.get_num_kv_heads(
+        vllm_config.parallel_config)
+    head_size = vllm_config.model_config.get_head_size()
    scale = 1.0 / (head_size**0.5)
    impl = impl_cls(
        num_heads=num_heads,
        head_size=head_size,
        scale=scale,
-        num_kv_heads=num_heads,
+        num_kv_heads=num_kv_heads,
        alibi_slopes=None,
        sliding_window=None,
        kv_cache_dtype="auto",
@@ -255,6 +296,8 @@ def test_backend_correctness(batch_spec_name: str, model: str):
        batch_spec, vllm_config.cache_config.block_size, device)

    # 3. Simulate Paged KV Cache and a realistic slot_mapping
+    # Note: In vLLM, block_id=0 is reserved as the null block and should not
+    # be used
    block_table = common_attn_metadata.block_table_tensor
    num_blocks = vllm_config.cache_config.num_gpu_blocks or 1000
    kv_cache = torch.empty(2,
@@ -267,7 +310,9 @@ def test_backend_correctness(batch_spec_name: str, model: str):
    kv_cache_flat = kv_cache.view(2, -1, num_kv_heads, head_size)

    # Populate the cache with the context tokens
-    start_block_idx = 0
+    # Start from block_id=1 since block_id=0 is considered the null block in
+    # vLLM
+    start_block_idx = 1
    for i in range(batch_size):
        k_context, v_context = all_k_context[i], all_v_context[i]
        start = start_block_idx * block_size
@@ -279,13 +324,18 @@ def test_backend_correctness(batch_spec_name: str, model: str):
        start_block_idx += cdiv(seq_lens[i], block_size)

    blocks_end = start_block_idx
-    # randomly permute the context blocks
-    perm = torch.arange(blocks_end)  #torch.randperm(blocks_end)
-    inv_perm = torch.argsort(perm)
-    kv_cache = kv_cache[:, perm, ...]
+    # randomly permute the context blocks (excluding block 0 which is null)
+    perm = torch.randperm(blocks_end -
+                          1) + 1  # Random permutation starting from block 1
+    inv_perm = torch.zeros(blocks_end, dtype=torch.long, device=device)
+    inv_perm[1:] = torch.argsort(
+        perm) + 1  # Add 1 to account for starting from block 1
+    kv_cache[:, 1:blocks_end, ...] = kv_cache[:, perm, ...]

    # Construct the right block table
-    start_block_idx = 0
+    # Start from block_id=1 since block_id=0 is considered the null block in
+    # vLLM
+    start_block_idx = 1
    for i in range(batch_size):
        num_blocks = cdiv(seq_lens[i], block_size)
        start = start_block_idx
@@ -306,14 +356,23 @@ def test_backend_correctness(batch_spec_name: str, model: str):

    # 4. Run vLLM backends and compare
    # Note: flex_attention has known Triton kernel compatibility issues
-    # with test infrastructure
-    backends_to_test = ["flash_attn"]  # flex_attention has compilation issues
-    for backend_name in backends_to_test:
+    # with test infrastructures
+    for backend_name in BACKENDS_TO_TEST:
+        # FlashAttention + FlexAttention:
+        #   [2, num_blocks, block_size, num_kv_heads, head_size]
+        # FlashInfer:
+        #   [num_blocks, 2, block_size, num_kv_heads, head_size]
+        # Select the appropriate KV cache format for each backend
+        kv_cache_for_backend = kv_cache
+        if backend_name == "flashinfer":
+            kv_cache_for_backend = kv_cache.transpose(0, 1)
+
        backend_output = run_attention_backend(backend_name, kv_cache_spec,
                                               vllm_config, device,
                                               common_attn_metadata,
                                               query_vllm, key_vllm,
-                                               value_vllm, kv_cache)
+                                               value_vllm,
+                                               kv_cache_for_backend)

        # Check shape and dtype consistency
        assert backend_output.shape == sdpa_output.shape, (
@@ -330,6 +389,14 @@ def test_backend_correctness(batch_spec_name: str, model: str):
        rtol = 1e-5 if backend_output.dtype == torch.float32 else 1e-2
        atol = 1e-4 if backend_output.dtype == torch.float32 else 1e-3

+        # Flashinfer may have slightly different numerical behavior
+        if backend_name == "flashinfer":
+            atol = 1e-3 if backend_output.dtype == torch.float32 else 5e-3
+
+        # Flex_attention may have slightly different numerical behavior
+        if backend_name == "flex_attention":
+            atol = 1e-2 if backend_output.dtype == torch.float32 else 1e-2
+
        max_diff = torch.max(torch.abs(backend_output - sdpa_output)).item()
        assert torch.allclose(
            backend_output, sdpa_output, rtol=rtol, atol=atol), (
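
Aside: the block-permutation logic introduced above is easy to sanity-check on its own. The standalone sketch below is not part of the patch; the block count and tensor names are illustrative. It verifies that inv_perm maps each original block index to its new position after the shuffle while the null block 0 stays in place.

import torch

# Illustrative values only; the test derives blocks_end from the batch spec.
blocks_end = 8
kv = torch.arange(blocks_end)  # stand-in for kv_cache blocks, kv[b] == b

# Same recipe as the patch: permute blocks 1..blocks_end-1, keep block 0 fixed.
perm = torch.randperm(blocks_end - 1) + 1
inv_perm = torch.zeros(blocks_end, dtype=torch.long)
inv_perm[1:] = torch.argsort(perm) + 1

permuted = kv.clone()
permuted[1:blocks_end] = kv[perm]  # mirrors kv_cache[:, 1:blocks_end, ...] = kv_cache[:, perm, ...]

# inv_perm[b] is the new position of original block b, so indexing the
# shuffled cache with it recovers the original block ids.
assert permuted[0] == 0
assert torch.equal(permuted[inv_perm[1:]], kv[1:])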