diff --git a/tests/attention/test_fp8_prefill.py b/tests/attention/test_fp8_prefill.py
index 414173f452..1b8ebc75cc 100644
--- a/tests/attention/test_fp8_prefill.py
+++ b/tests/attention/test_fp8_prefill.py
@@ -66,7 +66,7 @@ def test_batch_prefill_with_paged_kv_cache_fp8_calibration_scale(
     workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.int8).to(0)
     wrapper_f16 = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
-        workspace_buffer, kv_layout
+        workspace_buffer, kv_layout, backend="fa2"
     )
     wrapper_f16.plan(
         qo_indptr,
@@ -90,7 +90,7 @@ def test_batch_prefill_with_paged_kv_cache_fp8_calibration_scale(
     kv_data_fp8 = torch.cat([k_fp8, v_fp8], dim=1)
     wrapper_f8 = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
-        workspace_buffer, kv_layout
+        workspace_buffer, kv_layout, backend="fa2"
     )
     wrapper_f8.plan(
         qo_indptr,
@@ -156,7 +156,7 @@ def test_batch_decode_with_prefill_with_paged_kv_cache(
     workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.int8).to(0)
     wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
-        workspace_buffer, kv_layout
+        workspace_buffer, kv_layout, backend="fa2"
     )
     wrapper.plan(
         qo_indptr,
@@ -173,7 +173,7 @@ def test_batch_decode_with_prefill_with_paged_kv_cache(
     o_fp8 = wrapper.run(q, kv_data)
     decode_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
-        workspace_buffer, kv_layout
+        workspace_buffer, kv_layout, backend="fa2"
     )
     decode_wrapper.plan(
         kv_indptr,