 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-
-import unittest
-from typing import ClassVar
-
 import numpy as np
+import pytest
 import torch
 from compressed_tensors.transform.utils.hadamard import (
     deterministic_hadamard_matrix)

 from vllm._custom_ops import fusedQuantizeMx, matmul_mxf4_bf16_tn
+from vllm.platforms import current_platform
 from vllm.qutlass_utils.utils import to_blocked

+if not torch.cuda.is_available():
+    pytest.skip("CUDA required for these tests.", allow_module_level=True)
+
+if not (current_platform.has_device_capability(100)
+        or current_platform.has_device_capability(120)):
+    pytest.skip(
+        reason="Tests require compute capability 10.0 (100) or 12.0 (120).",
+        allow_module_level=True,
+    )
+

+# ----- Helpers -----
 def get_hadamard_matrix(group_size: int, dtype: torch.dtype,
                         device: torch.device):
     return (
@@ -176,141 +185,119 @@ def _forward_quantize_ref(x: torch.Tensor,
     )


-@unittest.skipUnless(torch.cuda.is_available(),
-                     "CUDA required for these tests")
-class Test(unittest.TestCase):
-    dtype: ClassVar[torch.dtype]
-    device: ClassVar[torch.device]
-
-    @classmethod
-    def setUpClass(cls):
-        seed = 0
-        np.random.seed(seed)
-        torch.random.manual_seed(seed)
-        cls.dtype = torch.bfloat16
-        cls.device = torch.device("cuda:0")
-
-    def run_problem(self, m, n, k, had_size):
-        print(m, n, k)
-        hadamard_matrix = get_hadamard_matrix(had_size, self.dtype,
-                                              self.device)
-
-        a = torch.rand(m, k, dtype=self.dtype, device=self.device) * 25.0
-        b = torch.rand(n, k, dtype=self.dtype, device=self.device) * 25.0
-
-        a_e2m1, a_e8m0 = fusedQuantizeMx(a, hadamard_matrix, method="quest")
-        b_e2m1, b_e8m0 = fusedQuantizeMx(b, hadamard_matrix, method="quest")
-
-        a_dq, *_ = _dq_fp4(a_e2m1, a_e8m0[:m, :k], alpha=1.0)
-        b_dq, *_ = _dq_fp4(b_e2m1, b_e8m0[:n, :k], alpha=1.0)
-        out_ref = a_dq @ b_dq.transpose(-2, -1)
-
-        a_scale_block = to_blocked(a_e8m0, True)
-        b_scale_block = to_blocked(b_e8m0, True)
-        alpha = torch.Tensor([1.0]).to(self.device)
-        out = matmul_mxf4_bf16_tn(a_e2m1, b_e2m1, a_scale_block, b_scale_block,
-                                  alpha)
-        assert out.equal(out_ref.to(dtype=out.dtype))
-
-    def test_fused_quantization(self):
-        dtype, device = self.dtype, self.device
-
-        def _absmax_case(rot_size: int):
-            h = get_hadamard_matrix(rot_size, dtype, device)
-            x = torch.randn(2, 4096, 4096, dtype=dtype, device=device) * 25.0
-
-            xh_dq_ref, _, _ = _forward_quantize_ref(x, h, rot_size, False)
-            xh_e2m1, xh_e8m0 = fusedQuantizeMx(x, h, method="abs_max")
-            xh_e8m0 = xh_e8m0.reshape(2, 4096, 4096 // 32)  #
-            xh_dq, *_ = _dq_fp4(xh_e2m1, xh_e8m0, alpha=3.0)
-
-            torch.testing.assert_close(xh_dq, xh_dq_ref, rtol=0.34, atol=100)
-            assert (xh_dq != xh_dq_ref).float().mean() <= 1e-4
-
-            m, n, k = 1, 504, 4096
-            a = torch.randn(m, k, dtype=dtype, device=device) * 25.0
-            b = torch.randn(n, k, dtype=dtype, device=device) * 25.0
-
-            a_e2m1, a_e8m0 = fusedQuantizeMx(a, h, method="abs_max")
-            b_e2m1, b_e8m0 = fusedQuantizeMx(b, h, method="abs_max")
-            a_dq, *_ = _dq_fp4(a_e2m1, a_e8m0[:m, :k], alpha=1.0)
-            b_dq, *_ = _dq_fp4(b_e2m1, b_e8m0[:n, :k], alpha=1.0)
-            out_ref = a_dq @ b_dq.transpose(-2, -1)
-
-            a_scale_block = to_blocked(a_e8m0)
-            b_scale_block = to_blocked(b_e8m0)
-            alpha = torch.Tensor([1.0]).to(device)
-            out = matmul_mxf4_bf16_tn(a_e2m1, b_e2m1, a_scale_block,
-                                      b_scale_block, alpha)
-            assert out.equal(out_ref.to(dtype=out.dtype))
-
-        def _quest_case(rot_size: int):
-            h = get_hadamard_matrix(rot_size, dtype, device)
-            x = torch.randn(2, 4096, 4096, dtype=dtype, device=device) * 25.0
-
-            xh_dq_ref, _, _ = _forward_quantize_ref(x, h, rot_size, True)
-            xh_e2m1, xh_e8m0 = fusedQuantizeMx(x, h, method="quest")
-            xh_e8m0 = xh_e8m0.reshape(2, 4096, 4096 // 32)  #
-            xh_dq, *_ = _dq_fp4(xh_e2m1, xh_e8m0, alpha=1.0)
-            torch.testing.assert_close(xh_dq, xh_dq_ref, rtol=0.34, atol=100)
-            assert (xh_dq != xh_dq_ref).float().mean() <= 1e-4
-
-            m, n, k = 504, 504, 2048
-            a = torch.randn(m, k, dtype=dtype, device=device) * 25.0
-            b = torch.randn(n, k, dtype=dtype, device=device) * 25.0
-
-            a_e2m1, a_e8m0 = fusedQuantizeMx(a, h, method="quest")
-            b_e2m1, b_e8m0 = fusedQuantizeMx(b, h, method="quest")
-            a_dq, *_ = _dq_fp4(a_e2m1, a_e8m0[:m, :k], alpha=1.0)
-            b_dq, *_ = _dq_fp4(b_e2m1, b_e8m0[:n, :k], alpha=1.0)
-            out_ref = a_dq @ b_dq.transpose(-2, -1)
-
-            a_scale_block = to_blocked(a_e8m0, True)
-            b_scale_block = to_blocked(b_e8m0, True)
-            alpha = torch.Tensor([1.0]).to(device)
-            out = matmul_mxf4_bf16_tn(a_e2m1, b_e2m1, a_scale_block,
-                                      b_scale_block, alpha)
-            assert out.equal(out_ref.to(dtype=out.dtype))
-
-        for rs in (32, 64, 128):
-            _absmax_case(rs)
-        for rs in (32, 64, 128):
-            _quest_case(rs)
-
-    def test_llama_shapes(self):
-        print()
-        MODELS = {
-            " 7B": [
-                (4096, 3 * 4096),
-                (4096, 4096),
-                (4096, 2 * 10752),
-                (10752, 4096),
-            ],
-            "13B": [
-                (5120, 3 * 5120),
-                (5120, 5120),
-                (5120, 2 * 13568),
-                (13568, 5120),
-            ],
-            "33B": [
-                (6656, 3 * 6656),
-                (6656, 6656),
-                (6656, 2 * 17664),
-                (17664, 6656),
-            ],
-            "70B": [
-                (8192, 3 * 8192),
-                (8192, 8192),
-                (8192, 2 * 21760),
-                (21760, 8192),
-            ],
-        }
-        for _, layers in MODELS.items():
-            for layer in layers:
-                for batch in [1, 16]:
-                    for had_size in [32, 64, 128]:
-                        self.run_problem(batch, layer[1], layer[0], had_size)
-
-
-if __name__ == "__main__":
-    unittest.main()
+DTYPE = torch.bfloat16
+DEVICE = torch.device("cuda:0")
+
+ROT_SIZES = [32, 64, 128]
+SEEDS = [0]
+BATCHES = [1, 16]
+
+LLAMA_MODELS = {
+    "7B": [(4096, 3 * 4096), (4096, 4096), (4096, 2 * 10752), (10752, 4096)],
+    "13B": [(5120, 3 * 5120), (5120, 5120), (5120, 2 * 13568), (13568, 5120)],
+    "33B": [(6656, 3 * 6656), (6656, 6656), (6656, 2 * 17664), (17664, 6656)],
+    "70B": [(8192, 3 * 8192), (8192, 8192), (8192, 2 * 21760), (21760, 8192)],
+}
+
+
+@pytest.fixture(autouse=True)
+def _seed_each_test():
+    current_platform.seed_everything(0)
+    np.random.seed(0)
+    torch.random.manual_seed(0)
+
+
+@pytest.mark.parametrize("rot_size", ROT_SIZES)
+@torch.inference_mode()
+def test_fused_quantization_absmax(rot_size: int):
+    dtype, device = DTYPE, DEVICE
+    h = get_hadamard_matrix(rot_size, dtype, device)
+    x = torch.randn(2, 4096, 4096, dtype=dtype, device=device) * 25.0
+
+    xh_dq_ref, _, _ = _forward_quantize_ref(x, h, rot_size, quest=False)
+    xh_e2m1, xh_e8m0 = fusedQuantizeMx(x, h, method="abs_max")
+    xh_e8m0 = xh_e8m0.reshape(2, 4096, 4096 // 32)
+    xh_dq, *_ = _dq_fp4(xh_e2m1, xh_e8m0, alpha=3.0)
+
+    torch.testing.assert_close(xh_dq, xh_dq_ref, rtol=0.34, atol=100)
+    assert (xh_dq != xh_dq_ref).float().mean() <= 1e-4
+
+    m, n, k = 1, 504, 4096
+    a = torch.randn(m, k, dtype=dtype, device=device) * 25.0
+    b = torch.randn(n, k, dtype=dtype, device=device) * 25.0
+
+    a_e2m1, a_e8m0 = fusedQuantizeMx(a, h, method="abs_max")
+    b_e2m1, b_e8m0 = fusedQuantizeMx(b, h, method="abs_max")
+    a_dq, *_ = _dq_fp4(a_e2m1, a_e8m0[:m, :k], alpha=1.0)
+    b_dq, *_ = _dq_fp4(b_e2m1, b_e8m0[:n, :k], alpha=1.0)
+    out_ref = a_dq @ b_dq.transpose(-2, -1)
+
+    a_scale_block = to_blocked(a_e8m0)
+    b_scale_block = to_blocked(b_e8m0)
+    alpha = torch.tensor([1.0], device=device)
+    out = matmul_mxf4_bf16_tn(a_e2m1, b_e2m1, a_scale_block, b_scale_block,
+                              alpha)
+    assert out.equal(out_ref.to(dtype=out.dtype))
+
+
+@pytest.mark.parametrize("rot_size", ROT_SIZES)
+@torch.inference_mode()
+def test_fused_quantization_quest(rot_size: int):
+    dtype, device = DTYPE, DEVICE
+    h = get_hadamard_matrix(rot_size, dtype, device)
+    x = torch.randn(2, 4096, 4096, dtype=dtype, device=device) * 25.0
+
+    xh_dq_ref, _, _ = _forward_quantize_ref(x, h, rot_size, quest=True)
+    xh_e2m1, xh_e8m0 = fusedQuantizeMx(x, h, method="quest")
+    xh_e8m0 = xh_e8m0.reshape(2, 4096, 4096 // 32)
+    xh_dq, *_ = _dq_fp4(xh_e2m1, xh_e8m0, alpha=1.0)
+
+    torch.testing.assert_close(xh_dq, xh_dq_ref, rtol=0.34, atol=100)
+    assert (xh_dq != xh_dq_ref).float().mean() <= 1e-4
+
+    m, n, k = 504, 504, 2048
+    a = torch.randn(m, k, dtype=dtype, device=device) * 25.0
+    b = torch.randn(n, k, dtype=dtype, device=device) * 25.0
+
+    a_e2m1, a_e8m0 = fusedQuantizeMx(a, h, method="quest")
+    b_e2m1, b_e8m0 = fusedQuantizeMx(b, h, method="quest")
+    a_dq, *_ = _dq_fp4(a_e2m1, a_e8m0[:m, :k], alpha=1.0)
+    b_dq, *_ = _dq_fp4(b_e2m1, b_e8m0[:n, :k], alpha=1.0)
+    out_ref = a_dq @ b_dq.transpose(-2, -1)
+
+    a_scale_block = to_blocked(a_e8m0, True)
+    b_scale_block = to_blocked(b_e8m0, True)
+    alpha = torch.tensor([1.0], device=device)
+    out = matmul_mxf4_bf16_tn(a_e2m1, b_e2m1, a_scale_block, b_scale_block,
+                              alpha)
+    assert out.equal(out_ref.to(dtype=out.dtype))
+
+
+@pytest.mark.parametrize("model", list(LLAMA_MODELS.keys()))
+@pytest.mark.parametrize("layer_idx", [0, 1, 2, 3])
+@pytest.mark.parametrize("batch", [1, 16])
+@pytest.mark.parametrize("had_size", ROT_SIZES)
+@torch.inference_mode()
+def test_llama_shapes(model: str, layer_idx: int, batch: int, had_size: int):
+    dtype, device = DTYPE, DEVICE
+    m = batch
+    k, n = LLAMA_MODELS[model][layer_idx]
+
+    h = get_hadamard_matrix(had_size, dtype, device)
+
+    a = torch.rand(m, k, dtype=dtype, device=device) * 25.0
+    b = torch.rand(n, k, dtype=dtype, device=device) * 25.0
+
+    a_e2m1, a_e8m0 = fusedQuantizeMx(a, h, method="quest")
+    b_e2m1, b_e8m0 = fusedQuantizeMx(b, h, method="quest")
+
+    a_dq, *_ = _dq_fp4(a_e2m1, a_e8m0[:m, :k], alpha=1.0)
+    b_dq, *_ = _dq_fp4(b_e2m1, b_e8m0[:n, :k], alpha=1.0)
+    out_ref = a_dq @ b_dq.transpose(-2, -1)
+
+    a_scale_block = to_blocked(a_e8m0, True)
+    b_scale_block = to_blocked(b_e8m0, True)
+    alpha = torch.tensor([1.0], device=device)
+    out = matmul_mxf4_bf16_tn(a_e2m1, b_e2m1, a_scale_block, b_scale_block,
+                              alpha)
+    assert out.equal(out_ref.to(dtype=out.dtype))
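
As a usage sketch (not part of the commit): after this conversion the tests are collected by pytest and skipped at module level on machines without CUDA or without compute capability 10.0/12.0. The snippet below shows one way to invoke them programmatically; the test file path is an assumption made for illustration, not something stated in the diff.

# Minimal sketch: run the converted pytest tests from Python.
# The path "tests/kernels/quantization/test_mxfp4_qutlass.py" is an assumed
# location for this module.
import sys

import pytest

if __name__ == "__main__":
    ret = pytest.main([
        "-q",
        "tests/kernels/quantization/test_mxfp4_qutlass.py",
        "-k", "fused_quantization or llama_shapes",
    ])
    sys.exit(int(ret))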