@@ -9,7 +9,6 @@
 from vllm.vllm_flash_attn import (
     fa_version_unsupported_reason,
     flash_attn_varlen_func,
-    flash_attn_with_kvcache,
     is_fa_version_supported,
 )
 
@@ -83,124 +82,6 @@ def ref_paged_attn(
     return torch.cat(outputs, dim=0)
 
 
-@pytest.mark.parametrize("use_out", [True, False])
-@pytest.mark.parametrize("kv_lens", [[1328, 18, 463], [1, 54, 293, 70]])
-@pytest.mark.parametrize("num_heads", NUM_HEADS)
-@pytest.mark.parametrize("head_size", HEAD_SIZES)
-@pytest.mark.parametrize("block_size", BLOCK_SIZES)
-@pytest.mark.parametrize("dtype", DTYPES)
-@pytest.mark.parametrize("soft_cap", SOFT_CAPS)
-@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
-@pytest.mark.parametrize("sliding_window", SLIDING_WINDOWS)
-@pytest.mark.parametrize("fa_version", [2, 3])
-@pytest.mark.parametrize("q_dtype", QDTYPES)
-@torch.inference_mode()
-def test_flash_attn_with_paged_kv(
-    use_out: bool,
-    kv_lens: list[int],
-    num_heads: tuple[int, int],
-    head_size: int,
-    dtype: torch.dtype,
-    block_size: int,
-    soft_cap: float | None,
-    num_blocks: int,
-    sliding_window: int | None,
-    fa_version: int,
-    q_dtype: torch.dtype | None,
-) -> None:
-    torch.set_default_device("cuda")
-    if not is_fa_version_supported(fa_version):
-        pytest.skip(
-            f"Flash attention version {fa_version} not supported due "
-            f'to: "{fa_version_unsupported_reason(fa_version)}"'
-        )
-    if q_dtype is not None and (dtype != torch.bfloat16 or fa_version == 2):
-        pytest.skip(
-            "Flash attention with quantized inputs is only "
-            "supported on version 3 with bfloat16 base type"
-        )
-
-    current_platform.seed_everything(0)
-    num_seqs = len(kv_lens)
-    num_query_heads = num_heads[0]
-    num_kv_heads = num_heads[1]
-    assert num_query_heads % num_kv_heads == 0
-    max_kv_len = max(kv_lens)
-    scale = head_size**-0.5
-    window_size = (sliding_window - 1, 0) if sliding_window is not None else (-1, -1)
-
-    query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype)
-    key_cache = torch.randn(
-        num_blocks, block_size, num_kv_heads, head_size, dtype=dtype
-    )
-    value_cache = torch.randn_like(key_cache)
-    kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int32)
-
-    max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
-    block_tables = torch.randint(
-        0, num_blocks, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32
-    )
-
-    q = query.unsqueeze(1)
-    out = torch.empty_like(q) if use_out else None
-
-    maybe_quantized_query = q
-    maybe_quantized_key_cache = key_cache
-    maybe_quantized_value_cache = value_cache
-    q_descale = None
-    k_descale = None
-    v_descale = None
-    if q_dtype is not None:
-        # QKV are drawn from N(0, 1): no need for a fp8 scaling factor
-        maybe_quantized_query = q.to(q_dtype)
-        maybe_quantized_key_cache = key_cache.to(q_dtype)
-        maybe_quantized_value_cache = value_cache.to(q_dtype)
-
-        scale_shape = (num_seqs, num_kv_heads)
-        q_descale = torch.ones(scale_shape, dtype=torch.float32)
-        k_descale = torch.ones(scale_shape, dtype=torch.float32)
-        v_descale = torch.ones(scale_shape, dtype=torch.float32)
-
-    output = flash_attn_with_kvcache(
-        q=maybe_quantized_query,
-        k_cache=maybe_quantized_key_cache,
-        v_cache=maybe_quantized_value_cache,
-        out=out,
-        softmax_scale=scale,
-        causal=True,
-        block_table=block_tables,
-        cache_seqlens=kv_lens_tensor,
-        softcap=soft_cap if soft_cap is not None else 0,
-        window_size=window_size,
-        fa_version=fa_version,
-        q_descale=q_descale,
-        k_descale=k_descale,
-        v_descale=v_descale,
-    )
-    output = output if not use_out else out
-    output = output.squeeze(1)
-
-    atol, rtol = 1.5e-2, 1e-2
-    if q_dtype is not None:
-        atol, rtol = 1.5e-1, 1.5e-1
-
-    ref_output = ref_paged_attn(
-        query=query,
-        key_cache=key_cache,
-        value_cache=value_cache,
-        query_lens=[1] * num_seqs,
-        kv_lens=kv_lens,
-        block_tables=block_tables,
-        scale=scale,
-        soft_cap=soft_cap,
-        sliding_window=sliding_window,
-    )
-    (
-        torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol),
-        f"{torch.max(torch.abs(output - ref_output))}",
-    )
-
-
 @pytest.mark.parametrize("use_out", [True, False])
 @pytest.mark.parametrize(
     "seq_lens", [[(1, 1328), (5, 18), (129, 463)], [(1, 523), (1, 37), (1, 2011)]]
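
Note: the deleted `test_flash_attn_with_paged_kv` exercised single-token (decode) queries against a paged KV cache via `flash_attn_with_kvcache`. The same workload can be phrased as a varlen call in which every sequence contributes exactly one query token. The sketch below is illustrative only and not part of this commit: it assumes `flash_attn_varlen_func` accepts paged-KV arguments (`cu_seqlens_q`, `seqused_k`, `block_table`) in the style of the varlen tests that remain in this file, and all tensor names are hypothetical.

```python
import torch

# Assumption: import and call signature follow the surviving varlen tests.
from vllm.vllm_flash_attn import flash_attn_varlen_func

torch.set_default_device("cuda")
num_seqs, num_query_heads, num_kv_heads, head_size = 3, 8, 2, 128
num_blocks, block_size = 128, 16
kv_lens = [1328, 18, 463]  # per-sequence KV-cache lengths (hypothetical values)

# Decode step: one query token per sequence.
query = torch.randn(num_seqs, num_query_heads, head_size, dtype=torch.bfloat16)
key_cache = torch.randn(
    num_blocks, block_size, num_kv_heads, head_size, dtype=torch.bfloat16
)
value_cache = torch.randn_like(key_cache)

# Map each sequence's logical KV blocks to arbitrary physical blocks.
max_num_blocks_per_seq = (max(kv_lens) + block_size - 1) // block_size
block_tables = torch.randint(
    0, num_blocks, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32
)

# With one token per sequence, cumulative query lengths are [0, 1, ..., num_seqs].
cu_query_lens = torch.arange(num_seqs + 1, dtype=torch.int32)

output = flash_attn_varlen_func(
    q=query,
    k=key_cache,
    v=value_cache,
    cu_seqlens_q=cu_query_lens,
    seqused_k=torch.tensor(kv_lens, dtype=torch.int32),
    max_seqlen_q=1,
    max_seqlen_k=max(kv_lens),
    softmax_scale=head_size**-0.5,
    causal=True,
    block_table=block_tables,
)
assert output.shape == (num_seqs, num_query_heads, head_size)
```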