Enable FSDP Test in CI (pytorch#207)
* Enable FSDP Test in CI

* yolo

* yolo

* yolo
msaroufim authored May 3, 2024
1 parent a049baf commit 1d58ced
Showing 2 changed files with 20 additions and 22 deletions.
13 changes: 9 additions & 4 deletions dev-requirements.txt
@@ -1,10 +1,15 @@
-pytest
+# Test utilities
+pytest==7.4.0
 expecttest
 unittest-xml-reporting
 parameterized
 packaging
+transformers
+
+# For prototype features and benchmarks
 bitsandbytes #needed for testing triton quant / dequant ops for 8-bit optimizers
-matplotlib # needed for triton benchmarking
-pandas # also for triton benchmarking
-transformers #for galore testing
+matplotlib
+pandas
+
+# Custom CUDA Extensions
 ninja
29 changes: 11 additions & 18 deletions test/hqq/test_triton_mm.py
@@ -1,17 +1,11 @@
 # Skip entire test if triton is not available, otherwise CI failure
 import pytest
-try:
-    import triton
-    import hqq
-    if int(triton.__version__.split(".")[0]) < 3:
-        pytest.skip("triton >= 3.0.0 is required to run this test", allow_module_level=True)
-except ImportError:
-    pytest.skip("triton and hqq required to run this test", allow_module_level=True)
-
 import itertools
 import torch
 
-from hqq.core.quantize import HQQLinear, BaseQuantizeConfig
-
+triton = pytest.importorskip("triton", minversion="3.0.0", reason="Triton > 3.0.0 required to run this test")
+hqq = pytest.importorskip("hqq", reason="hqq required to run this test")
+HQQLinear = pytest.importorskip("hqq.core.quantize.HQQLinear", reason="HQQLinear required to run this test")
+BaseQuantizeConfig = pytest.importorskip("hqq.core.quantize.BaseQuantizeConfig", reason="HQQLinear required to run this test")
+
 from torchao.prototype.hqq import triton_mixed_mm, pack_2xint4


@@ -61,7 +55,7 @@ def test_mixed_mm(shape, group_size, axis, dtype, transposed, kernel_type, quant
         **dict(group_size=group_size, axis=axis),
     }
     M, N, K = shape
 
     linear = torch.nn.Linear(K, N, bias=False, dtype=dtype, device="cuda")
 
     quant_config = BaseQuantizeConfig(
@@ -81,24 +75,23 @@ def test_mixed_mm(shape, group_size, axis, dtype, transposed, kernel_type, quant
     scales, zeros = meta["scale"], meta["zero"]
     scales = scales.reshape(N, -1)
     zeros = zeros.reshape(N, -1)
 
     if transposed:
         x = torch.randn(M, N, dtype=dtype, device="cuda")
-        hqq_out = x @ W_dq
+        hqq_out = x @ W_dq
 
-        #Pack uint8 W_q, then run fused dequant matmul
+        #Pack uint8 W_q, then run fused dequant matmul
         packed_w = pack_2xint4(W_q)
         tt_out = triton_mixed_mm(
             x, packed_w, scales, zeros, transposed=True, group_size=group_size, fp8_fast_accum=False, kernel_type=kernel_type
         )
     else:
         x = torch.randn(M, K, dtype=dtype, device="cuda")
-        hqq_out = x @ W_dq.T
+        hqq_out = x @ W_dq.T
 
         packed_w = pack_2xint4(W_q.T)
         tt_out = triton_mixed_mm(
             x, packed_w, scales.T, zeros.T, transposed=False, group_size=group_size, fp8_fast_accum=False, kernel_type=kernel_type
         )
 
     assert check(hqq_out, tt_out, max_diff=1e-2 if dtype == torch.bfloat16 else 1e-3)

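A note on the pattern the new version of this test file relies on (illustrative, not part of the commit): pytest.importorskip imports an optional dependency at collection time and skips the whole module when the import fails or the installed version is older than minversion, which is what keeps a missing triton or hqq from failing CI. A minimal, self-contained sketch, assuming a test that only needs triton:

# Minimal sketch (illustrative only, not from this commit).
import pytest

# Skip this entire module unless triton >= 3.0.0 can be imported.
triton = pytest.importorskip("triton", minversion="3.0.0")

import torch


def test_triton_is_usable():
    # If collection reached this test, the import above succeeded.
    assert int(triton.__version__.split(".")[0]) >= 3
    assert torch.ones(1).item() == 1.0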