
Commit d0805ae

use pytest and cuda version check
Signed-off-by: LopezCastroRoberto <[email protected]>
1 parent c991133 commit d0805ae

5 files changed (+248, -272 lines)


.buildkite/test-pipeline.yaml

Lines changed: 2 additions & 0 deletions
@@ -800,6 +800,8 @@ steps:
     - pytest -v -s tests/compile/test_fusion_attn.py::test_attention_quant_pattern
     - pytest -v -s tests/kernels/moe/test_flashinfer.py
     - pytest -v -s tests/compile/test_silu_mul_quant_fusion.py
+    - pytest -v -s tests/kernels/quantization/test_nvfp4_qutlass.py
+    - pytest -v -s tests/kernels/quantization/test_mxfp4_qutlass.py

 - label: GPT-OSS Eval (Blackwell)
   timeout_in_minutes: 60

cmake/external_projects/qutlass.cmake

Lines changed: 9 additions & 2 deletions
@@ -32,7 +32,7 @@ endif()
 message(STATUS "[QUTLASS] QuTLASS is available at ${qutlass_SOURCE_DIR}")

 cuda_archs_loose_intersection(QUTLASS_ARCHS "12.0a;10.0a" "${CUDA_ARCHS}")
-if(QUTLASS_ARCHS)
+if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND QUTLASS_ARCHS)

   if(QUTLASS_ARCHS MATCHES "12\\.0a")
     set(QUTLASS_TARGET_CC 120)
@@ -82,5 +82,12 @@ if(QUTLASS_ARCHS)
   )

 else()
-  message(STATUS "[QUTLASS] skipping (no 12.0a/10.0a in CUDA_ARCHS='${CUDA_ARCHS}')")
+  if("${CMAKE_CUDA_COMPILER_VERSION}" VERSION_LESS "12.8")
+    message(STATUS
+      "[QUTLASS] Skipping build: CUDA 12.8 or newer is required (found ${CMAKE_CUDA_COMPILER_VERSION}).")
+  else()
+    message(STATUS
+      "[QUTLASS] Skipping build: no supported arch (12.0a / 10.0a) found in "
+      "CUDA_ARCHS='${CUDA_ARCHS}'.")
+  endif()
 endif()
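
Aside (not part of the commit): the gate above is purely build-time, so a build done with an older toolkit simply ships without the QuTLASS kernels. A rough runtime analogue is sketched below, assuming the "12.8 or newer" threshold from the skip message is the relevant one; torch.version.cuda and packaging.version are standard, but vLLM itself does not necessarily perform this exact check anywhere.

# Illustrative sketch only -- not code from this commit. It mirrors the
# CMake gate at runtime: QuTLASS kernels can only be present in a build
# whose CUDA toolkit met the threshold (12.8 per the skip message above).
import torch
from packaging.version import Version

def cuda_toolkit_allows_qutlass() -> bool:
    cuda = torch.version.cuda  # e.g. "12.8"; None in CPU-only builds
    return cuda is not None and Version(cuda) >= Version("12.8")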

tests/kernels/quantization/test_mxfp4_qutlass.py

Lines changed: 129 additions & 142 deletions
@@ -16,19 +16,28 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-
-import unittest
-from typing import ClassVar
-
 import numpy as np
+import pytest
 import torch
 from compressed_tensors.transform.utils.hadamard import (
     deterministic_hadamard_matrix)

 from vllm._custom_ops import fusedQuantizeMx, matmul_mxf4_bf16_tn
+from vllm.platforms import current_platform
 from vllm.qutlass_utils.utils import to_blocked

+if not torch.cuda.is_available():
+    pytest.skip("CUDA required for these tests.", allow_module_level=True)
+
+if not (current_platform.has_device_capability(100)
+        or current_platform.has_device_capability(120)):
+    pytest.skip(
+        reason="Tests require compute capability 10.0 (100) or 12.0 (120).",
+        allow_module_level=True,
+    )
+

+# ----- Helpers -----
 def get_hadamard_matrix(group_size: int, dtype: torch.dtype,
                         device: torch.device):
     return (
@@ -176,141 +185,119 @@ def _forward_quantize_ref(x: torch.Tensor,
     )


-@unittest.skipUnless(torch.cuda.is_available(),
-                     "CUDA required for these tests")
-class Test(unittest.TestCase):
-    dtype: ClassVar[torch.dtype]
-    device: ClassVar[torch.device]
-
-    @classmethod
-    def setUpClass(cls):
-        seed = 0
-        np.random.seed(seed)
-        torch.random.manual_seed(seed)
-        cls.dtype = torch.bfloat16
-        cls.device = torch.device("cuda:0")
-
-    def run_problem(self, m, n, k, had_size):
-        print(m, n, k)
-        hadamard_matrix = get_hadamard_matrix(had_size, self.dtype,
-                                              self.device)
-
-        a = torch.rand(m, k, dtype=self.dtype, device=self.device) * 25.0
-        b = torch.rand(n, k, dtype=self.dtype, device=self.device) * 25.0
-
-        a_e2m1, a_e8m0 = fusedQuantizeMx(a, hadamard_matrix, method="quest")
-        b_e2m1, b_e8m0 = fusedQuantizeMx(b, hadamard_matrix, method="quest")
-
-        a_dq, *_ = _dq_fp4(a_e2m1, a_e8m0[:m, :k], alpha=1.0)
-        b_dq, *_ = _dq_fp4(b_e2m1, b_e8m0[:n, :k], alpha=1.0)
-        out_ref = a_dq @ b_dq.transpose(-2, -1)
-
-        a_scale_block = to_blocked(a_e8m0, True)
-        b_scale_block = to_blocked(b_e8m0, True)
-        alpha = torch.Tensor([1.0]).to(self.device)
-        out = matmul_mxf4_bf16_tn(a_e2m1, b_e2m1, a_scale_block, b_scale_block,
-                                  alpha)
-        assert out.equal(out_ref.to(dtype=out.dtype))
-
-    def test_fused_quantization(self):
-        dtype, device = self.dtype, self.device
-
-        def _absmax_case(rot_size: int):
-            h = get_hadamard_matrix(rot_size, dtype, device)
-            x = torch.randn(2, 4096, 4096, dtype=dtype, device=device) * 25.0
-
-            xh_dq_ref, _, _ = _forward_quantize_ref(x, h, rot_size, False)
-            xh_e2m1, xh_e8m0 = fusedQuantizeMx(x, h, method="abs_max")
-            xh_e8m0 = xh_e8m0.reshape(2, 4096, 4096 // 32)  #
-            xh_dq, *_ = _dq_fp4(xh_e2m1, xh_e8m0, alpha=3.0)
-
-            torch.testing.assert_close(xh_dq, xh_dq_ref, rtol=0.34, atol=100)
-            assert (xh_dq != xh_dq_ref).float().mean() <= 1e-4
-
-            m, n, k = 1, 504, 4096
-            a = torch.randn(m, k, dtype=dtype, device=device) * 25.0
-            b = torch.randn(n, k, dtype=dtype, device=device) * 25.0
-
-            a_e2m1, a_e8m0 = fusedQuantizeMx(a, h, method="abs_max")
-            b_e2m1, b_e8m0 = fusedQuantizeMx(b, h, method="abs_max")
-            a_dq, *_ = _dq_fp4(a_e2m1, a_e8m0[:m, :k], alpha=1.0)
-            b_dq, *_ = _dq_fp4(b_e2m1, b_e8m0[:n, :k], alpha=1.0)
-            out_ref = a_dq @ b_dq.transpose(-2, -1)
-
-            a_scale_block = to_blocked(a_e8m0)
-            b_scale_block = to_blocked(b_e8m0)
-            alpha = torch.Tensor([1.0]).to(device)
-            out = matmul_mxf4_bf16_tn(a_e2m1, b_e2m1, a_scale_block,
-                                      b_scale_block, alpha)
-            assert out.equal(out_ref.to(dtype=out.dtype))
-
-        def _quest_case(rot_size: int):
-            h = get_hadamard_matrix(rot_size, dtype, device)
-            x = torch.randn(2, 4096, 4096, dtype=dtype, device=device) * 25.0
-
-            xh_dq_ref, _, _ = _forward_quantize_ref(x, h, rot_size, True)
-            xh_e2m1, xh_e8m0 = fusedQuantizeMx(x, h, method="quest")
-            xh_e8m0 = xh_e8m0.reshape(2, 4096, 4096 // 32)  #
-            xh_dq, *_ = _dq_fp4(xh_e2m1, xh_e8m0, alpha=1.0)
-            torch.testing.assert_close(xh_dq, xh_dq_ref, rtol=0.34, atol=100)
-            assert (xh_dq != xh_dq_ref).float().mean() <= 1e-4
-
-            m, n, k = 504, 504, 2048
-            a = torch.randn(m, k, dtype=dtype, device=device) * 25.0
-            b = torch.randn(n, k, dtype=dtype, device=device) * 25.0
-
-            a_e2m1, a_e8m0 = fusedQuantizeMx(a, h, method="quest")
-            b_e2m1, b_e8m0 = fusedQuantizeMx(b, h, method="quest")
-            a_dq, *_ = _dq_fp4(a_e2m1, a_e8m0[:m, :k], alpha=1.0)
-            b_dq, *_ = _dq_fp4(b_e2m1, b_e8m0[:n, :k], alpha=1.0)
-            out_ref = a_dq @ b_dq.transpose(-2, -1)
-
-            a_scale_block = to_blocked(a_e8m0, True)
-            b_scale_block = to_blocked(b_e8m0, True)
-            alpha = torch.Tensor([1.0]).to(device)
-            out = matmul_mxf4_bf16_tn(a_e2m1, b_e2m1, a_scale_block,
-                                      b_scale_block, alpha)
-            assert out.equal(out_ref.to(dtype=out.dtype))
-
-        for rs in (32, 64, 128):
-            _absmax_case(rs)
-        for rs in (32, 64, 128):
-            _quest_case(rs)
-
-    def test_llama_shapes(self):
-        print()
-        MODELS = {
-            " 7B": [
-                (4096, 3 * 4096),
-                (4096, 4096),
-                (4096, 2 * 10752),
-                (10752, 4096),
-            ],
-            "13B": [
-                (5120, 3 * 5120),
-                (5120, 5120),
-                (5120, 2 * 13568),
-                (13568, 5120),
-            ],
-            "33B": [
-                (6656, 3 * 6656),
-                (6656, 6656),
-                (6656, 2 * 17664),
-                (17664, 6656),
-            ],
-            "70B": [
-                (8192, 3 * 8192),
-                (8192, 8192),
-                (8192, 2 * 21760),
-                (21760, 8192),
-            ],
-        }
-        for _, layers in MODELS.items():
-            for layer in layers:
-                for batch in [1, 16]:
-                    for had_size in [32, 64, 128]:
-                        self.run_problem(batch, layer[1], layer[0], had_size)
-
-
-if __name__ == "__main__":
-    unittest.main()
+DTYPE = torch.bfloat16
+DEVICE = torch.device("cuda:0")
+
+ROT_SIZES = [32, 64, 128]
+SEEDS = [0]
+BATCHES = [1, 16]
+
+LLAMA_MODELS = {
+    "7B": [(4096, 3 * 4096), (4096, 4096), (4096, 2 * 10752), (10752, 4096)],
+    "13B": [(5120, 3 * 5120), (5120, 5120), (5120, 2 * 13568), (13568, 5120)],
+    "33B": [(6656, 3 * 6656), (6656, 6656), (6656, 2 * 17664), (17664, 6656)],
+    "70B": [(8192, 3 * 8192), (8192, 8192), (8192, 2 * 21760), (21760, 8192)],
+}
+
+
+@pytest.fixture(autouse=True)
+def _seed_each_test():
+    current_platform.seed_everything(0)
+    np.random.seed(0)
+    torch.random.manual_seed(0)
+
+
+@pytest.mark.parametrize("rot_size", ROT_SIZES)
+@torch.inference_mode()
+def test_fused_quantization_absmax(rot_size: int):
+    dtype, device = DTYPE, DEVICE
+    h = get_hadamard_matrix(rot_size, dtype, device)
+    x = torch.randn(2, 4096, 4096, dtype=dtype, device=device) * 25.0
+
+    xh_dq_ref, _, _ = _forward_quantize_ref(x, h, rot_size, quest=False)
+    xh_e2m1, xh_e8m0 = fusedQuantizeMx(x, h, method="abs_max")
+    xh_e8m0 = xh_e8m0.reshape(2, 4096, 4096 // 32)
+    xh_dq, *_ = _dq_fp4(xh_e2m1, xh_e8m0, alpha=3.0)
+
+    torch.testing.assert_close(xh_dq, xh_dq_ref, rtol=0.34, atol=100)
+    assert (xh_dq != xh_dq_ref).float().mean() <= 1e-4
+
+    m, n, k = 1, 504, 4096
+    a = torch.randn(m, k, dtype=dtype, device=device) * 25.0
+    b = torch.randn(n, k, dtype=dtype, device=device) * 25.0
+
+    a_e2m1, a_e8m0 = fusedQuantizeMx(a, h, method="abs_max")
+    b_e2m1, b_e8m0 = fusedQuantizeMx(b, h, method="abs_max")
+    a_dq, *_ = _dq_fp4(a_e2m1, a_e8m0[:m, :k], alpha=1.0)
+    b_dq, *_ = _dq_fp4(b_e2m1, b_e8m0[:n, :k], alpha=1.0)
+    out_ref = a_dq @ b_dq.transpose(-2, -1)
+
+    a_scale_block = to_blocked(a_e8m0)
+    b_scale_block = to_blocked(b_e8m0)
+    alpha = torch.tensor([1.0], device=device)
+    out = matmul_mxf4_bf16_tn(a_e2m1, b_e2m1, a_scale_block, b_scale_block,
+                              alpha)
+    assert out.equal(out_ref.to(dtype=out.dtype))
+
+
+@pytest.mark.parametrize("rot_size", ROT_SIZES)
+@torch.inference_mode()
+def test_fused_quantization_quest(rot_size: int):
+    dtype, device = DTYPE, DEVICE
+    h = get_hadamard_matrix(rot_size, dtype, device)
+    x = torch.randn(2, 4096, 4096, dtype=dtype, device=device) * 25.0
+
+    xh_dq_ref, _, _ = _forward_quantize_ref(x, h, rot_size, quest=True)
+    xh_e2m1, xh_e8m0 = fusedQuantizeMx(x, h, method="quest")
+    xh_e8m0 = xh_e8m0.reshape(2, 4096, 4096 // 32)
+    xh_dq, *_ = _dq_fp4(xh_e2m1, xh_e8m0, alpha=1.0)
+
+    torch.testing.assert_close(xh_dq, xh_dq_ref, rtol=0.34, atol=100)
+    assert (xh_dq != xh_dq_ref).float().mean() <= 1e-4
+
+    m, n, k = 504, 504, 2048
+    a = torch.randn(m, k, dtype=dtype, device=device) * 25.0
+    b = torch.randn(n, k, dtype=dtype, device=device) * 25.0
+
+    a_e2m1, a_e8m0 = fusedQuantizeMx(a, h, method="quest")
+    b_e2m1, b_e8m0 = fusedQuantizeMx(b, h, method="quest")
+    a_dq, *_ = _dq_fp4(a_e2m1, a_e8m0[:m, :k], alpha=1.0)
+    b_dq, *_ = _dq_fp4(b_e2m1, b_e8m0[:n, :k], alpha=1.0)
+    out_ref = a_dq @ b_dq.transpose(-2, -1)
+
+    a_scale_block = to_blocked(a_e8m0, True)
+    b_scale_block = to_blocked(b_e8m0, True)
+    alpha = torch.tensor([1.0], device=device)
+    out = matmul_mxf4_bf16_tn(a_e2m1, b_e2m1, a_scale_block, b_scale_block,
+                              alpha)
+    assert out.equal(out_ref.to(dtype=out.dtype))
+
+
+@pytest.mark.parametrize("model", list(LLAMA_MODELS.keys()))
+@pytest.mark.parametrize("layer_idx", [0, 1, 2, 3])
+@pytest.mark.parametrize("batch", [1, 16])
+@pytest.mark.parametrize("had_size", ROT_SIZES)
+@torch.inference_mode()
+def test_llama_shapes(model: str, layer_idx: int, batch: int, had_size: int):
+    dtype, device = DTYPE, DEVICE
+    m = batch
+    k, n = LLAMA_MODELS[model][layer_idx]
+
+    h = get_hadamard_matrix(had_size, dtype, device)
+
+    a = torch.rand(m, k, dtype=dtype, device=device) * 25.0
+    b = torch.rand(n, k, dtype=dtype, device=device) * 25.0
+
+    a_e2m1, a_e8m0 = fusedQuantizeMx(a, h, method="quest")
+    b_e2m1, b_e8m0 = fusedQuantizeMx(b, h, method="quest")
+
+    a_dq, *_ = _dq_fp4(a_e2m1, a_e8m0[:m, :k], alpha=1.0)
+    b_dq, *_ = _dq_fp4(b_e2m1, b_e8m0[:n, :k], alpha=1.0)
+    out_ref = a_dq @ b_dq.transpose(-2, -1)
+
+    a_scale_block = to_blocked(a_e8m0, True)
+    b_scale_block = to_blocked(b_e8m0, True)
+    alpha = torch.tensor([1.0], device=device)
+    out = matmul_mxf4_bf16_tn(a_e2m1, b_e2m1, a_scale_block, b_scale_block,
+                              alpha)
+    assert out.equal(out_ref.to(dtype=out.dtype))
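
Aside (not part of the diff): with the unittest class gone, individual cases can now be selected by name. A minimal local invocation of the new module is sketched below, assuming it is run from the vLLM repo root on a GPU with compute capability 10.0 or 12.0 and CUDA 12.8+, so the module-level skips above do not fire.

# Illustrative usage only.
import pytest

# Select just the abs_max cases; drop "-k absmax" to run the whole module,
# which is what the new Buildkite pipeline entries above do.
pytest.main([
    "-v", "-s",
    "-k", "absmax",
    "tests/kernels/quantization/test_mxfp4_qutlass.py",
])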
