diff --git a/dev-requirements.txt b/dev-requirements.txt
index 8a8ed1e491..76a984d939 100644
--- a/dev-requirements.txt
+++ b/dev-requirements.txt
@@ -1,10 +1,15 @@
-pytest
+# Test utilities
+pytest==7.4.0
 expecttest
+unittest-xml-reporting
 parameterized
 packaging
 transformers
+
+# For prototype features and benchmarks
 bitsandbytes #needed for testing triton quant / dequant ops for 8-bit optimizers
-matplotlib # needed for triton benchmarking
-pandas # also for triton benchmarking
-transformers #for galore testing
+matplotlib
+pandas
+
+# Custom CUDA Extensions
 ninja
diff --git a/test/hqq/test_triton_mm.py b/test/hqq/test_triton_mm.py
index 23f6c60f70..471ede4250 100644
--- a/test/hqq/test_triton_mm.py
+++ b/test/hqq/test_triton_mm.py
@@ -1,17 +1,11 @@
 # Skip entire test if triton is not available, otherwise CI failure
 import pytest
-try:
-    import triton
-    import hqq
-    if int(triton.__version__.split(".")[0]) < 3:
-        pytest.skip("triton >= 3.0.0 is required to run this test", allow_module_level=True)
-except ImportError:
-    pytest.skip("triton and hqq required to run this test", allow_module_level=True)
-
-import itertools
-import torch
-
-from hqq.core.quantize import HQQLinear, BaseQuantizeConfig
+import torch
+
+triton = pytest.importorskip("triton", minversion="3.0.0", reason="triton >= 3.0.0 required to run this test")
+quantize = pytest.importorskip("hqq.core.quantize", reason="hqq required to run this test")
+HQQLinear, BaseQuantizeConfig = quantize.HQQLinear, quantize.BaseQuantizeConfig
+
 from torchao.prototype.hqq import triton_mixed_mm, pack_2xint4
 
 
@@ -61,7 +55,7 @@ def test_mixed_mm(shape, group_size, axis, dtype, transposed, kernel_type, quant
         **dict(group_size=group_size, axis=axis),
     }
     M, N, K = shape
-    
+
     linear = torch.nn.Linear(K, N, bias=False, dtype=dtype, device="cuda")
 
     quant_config = BaseQuantizeConfig(
@@ -81,19 +75,19 @@ def test_mixed_mm(shape, group_size, axis, dtype, transposed, kernel_type, quant
     scales, zeros = meta["scale"], meta["zero"]
     scales = scales.reshape(N, -1)
     zeros = zeros.reshape(N, -1)
-    
+
     if transposed:
         x = torch.randn(M, N, dtype=dtype, device="cuda")
-        hqq_out = x @ W_dq
+        hqq_out = x @ W_dq
 
-        #Pack uint8 W_q, then run fused dequant matmul
+        #Pack uint8 W_q, then run fused dequant matmul
         packed_w = pack_2xint4(W_q)
         tt_out = triton_mixed_mm(
             x, packed_w, scales, zeros, transposed=True, group_size=group_size, fp8_fast_accum=False, kernel_type=kernel_type
        )
     else:
         x = torch.randn(M, K, dtype=dtype, device="cuda")
-        hqq_out = x @ W_dq.T
+        hqq_out = x @ W_dq.T
 
         packed_w = pack_2xint4(W_q.T)
         tt_out = triton_mixed_mm(
@@ -101,4 +95,3 @@ def test_mixed_mm(shape, group_size, axis, dtype, transposed, kernel_type, quant
         )
 
     assert check(hqq_out, tt_out, max_diff=1e-2 if dtype == torch.bfloat16 else 1e-3)
-
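
Note on the import header: pytest.importorskip only accepts a module path, not a dotted class path, and it returns the imported module, so HQQLinear and BaseQuantizeConfig are read as attributes after the skip check, and import torch is kept since the test body still calls torch.nn.Linear and torch.randn. Below is a minimal, self-contained sketch (not part of this diff; the file name and test function are illustrative only) showing how this module-level gate behaves when the test file is collected:

# sketch_importorskip.py -- illustrative only, not part of this PR
import pytest

# Skips every test in this file at collection time if triton is missing
# or its __version__ is older than 3.0.0.
triton = pytest.importorskip(
    "triton", minversion="3.0.0", reason="triton >= 3.0.0 required to run this test"
)

# importorskip returns the imported module; the classes are then read as
# attributes, since "hqq.core.quantize.HQQLinear" is not an importable module.
quantize = pytest.importorskip("hqq.core.quantize", reason="hqq required to run this test")
HQQLinear, BaseQuantizeConfig = quantize.HQQLinear, quantize.BaseQuantizeConfig


def test_gate_passed():
    # Runs only when both importorskip calls above succeeded.
    assert hasattr(triton, "__version__")
    assert callable(HQQLinear) and callable(BaseQuantizeConfig)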