From 4d41a9859b34e9a98f914c6fea0dfb52d1cff0f4 Mon Sep 17 00:00:00 2001 From: Mark Saroufim Date: Wed, 1 May 2024 20:47:32 -0700 Subject: [PATCH 1/4] Enable FSDP Test in CI --- dev-requirements.txt | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/dev-requirements.txt b/dev-requirements.txt index 8a8ed1e491..76a984d939 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -1,10 +1,15 @@ -pytest +# Test utilities +pytest==7.4.0 expecttest +unittest-xml-reporting parameterized packaging transformers + +# For prototype features and benchmarks bitsandbytes #needed for testing triton quant / dequant ops for 8-bit optimizers -matplotlib # needed for triton benchmarking -pandas # also for triton benchmarking -transformers #for galore testing +matplotlib +pandas + +# Custom CUDA Extensions ninja From f27517a3d49a3713adb303929fe47aa7d7049bc7 Mon Sep 17 00:00:00 2001 From: Mark Saroufim Date: Thu, 2 May 2024 09:08:50 -0700 Subject: [PATCH 2/4] yolo --- test/hqq/test_triton_mm.py | 25 ++++++++++--------------- 1 file changed, 10 insertions(+), 15 deletions(-) diff --git a/test/hqq/test_triton_mm.py b/test/hqq/test_triton_mm.py index 23f6c60f70..21f3820fb7 100644 --- a/test/hqq/test_triton_mm.py +++ b/test/hqq/test_triton_mm.py @@ -1,15 +1,11 @@ # Skip entire test if triton is not available, otherwise CI failure import pytest -try: - import triton - import hqq - if int(triton.__version__.split(".")[0]) < 3: - pytest.skip("triton >= 3.0.0 is required to run this test", allow_module_level=True) -except ImportError: - pytest.skip("triton and hqq required to run this test", allow_module_level=True) -import itertools -import torch +pytest.importorskip("triton", minversion="3.0.0", reason="triton >= 3.0.0 required to run this test") +pytest.importorskip("hqq", reason="hqq required to run this test") + +import triton +import hqq from hqq.core.quantize import HQQLinear, BaseQuantizeConfig from torchao.prototype.hqq import triton_mixed_mm, 
pack_2xint4 @@ -61,7 +57,7 @@ def test_mixed_mm(shape, group_size, axis, dtype, transposed, kernel_type, quant **dict(group_size=group_size, axis=axis), } M, N, K = shape - + linear = torch.nn.Linear(K, N, bias=False, dtype=dtype, device="cuda") quant_config = BaseQuantizeConfig( @@ -81,19 +77,19 @@ def test_mixed_mm(shape, group_size, axis, dtype, transposed, kernel_type, quant scales, zeros = meta["scale"], meta["zero"] scales = scales.reshape(N, -1) zeros = zeros.reshape(N, -1) - + if transposed: x = torch.randn(M, N, dtype=dtype, device="cuda") - hqq_out = x @ W_dq + hqq_out = x @ W_dq - #Pack uint8 W_q, then run fused dequant matmul + #Pack uint8 W_q, then run fused dequant matmul packed_w = pack_2xint4(W_q) tt_out = triton_mixed_mm( x, packed_w, scales, zeros, transposed=True, group_size=group_size, fp8_fast_accum=False, kernel_type=kernel_type ) else: x = torch.randn(M, K, dtype=dtype, device="cuda") - hqq_out = x @ W_dq.T + hqq_out = x @ W_dq.T packed_w = pack_2xint4(W_q.T) tt_out = triton_mixed_mm( @@ -101,4 +97,3 @@ def test_mixed_mm(shape, group_size, axis, dtype, transposed, kernel_type, quant ) assert check(hqq_out, tt_out, max_diff=1e-2 if dtype == torch.bfloat16 else 1e-3) - From 95dc58285b945bae9f86c36dbb27004422182b4d Mon Sep 17 00:00:00 2001 From: Mark Saroufim Date: Thu, 2 May 2024 09:29:43 -0700 Subject: [PATCH 3/4] yolo --- test/hqq/test_triton_mm.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/test/hqq/test_triton_mm.py b/test/hqq/test_triton_mm.py index 21f3820fb7..c7ac3ccab0 100644 --- a/test/hqq/test_triton_mm.py +++ b/test/hqq/test_triton_mm.py @@ -1,11 +1,8 @@ # Skip entire test if triton is not available, otherwise CI failure import pytest -pytest.importorskip("triton", minversion="3.0.0", reason="triton >= 3.0.0 required to run this test") -pytest.importorskip("hqq", reason="hqq required to run this test") - -import triton -import hqq +triton = pytest.importorskip("triton", minversion="3.0.0", reason="triton 
>= 3.0.0 required to run this test") hqq = pytest.importorskip("hqq", reason="hqq required to run this test") from hqq.core.quantize import HQQLinear, BaseQuantizeConfig from torchao.prototype.hqq import triton_mixed_mm, pack_2xint4 From fc7b43da0e52817e739fd7a83e18211c7d5ec541 Mon Sep 17 00:00:00 2001 From: Mark Saroufim Date: Thu, 2 May 2024 09:51:09 -0700 Subject: [PATCH 4/4] Resolve hqq quantize symbols via importorskip --- test/hqq/test_triton_mm.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/test/hqq/test_triton_mm.py b/test/hqq/test_triton_mm.py index c7ac3ccab0..471ede4250 100644 --- a/test/hqq/test_triton_mm.py +++ b/test/hqq/test_triton_mm.py @@ -1,10 +1,11 @@ # Skip entire test if triton is not available, otherwise CI failure import pytest -triton = pytest.importorskip("triton", minversion="3.0.0", reason="triton >= 3.0.0 required to run this test") +triton = pytest.importorskip("triton", minversion="3.0.0", reason="Triton >= 3.0.0 required to run this test") hqq = pytest.importorskip("hqq", reason="hqq required to run this test") +HQQLinear = pytest.importorskip("hqq.core.quantize", reason="hqq.core.quantize required to run this test").HQQLinear +BaseQuantizeConfig = pytest.importorskip("hqq.core.quantize", reason="hqq.core.quantize required to run this test").BaseQuantizeConfig -from hqq.core.quantize import HQQLinear, BaseQuantizeConfig from torchao.prototype.hqq import triton_mixed_mm, pack_2xint4