Showing 1 changed file with 212 additions and 0 deletions.
import pytest

triton = pytest.importorskip(
    "triton", minversion="3.0.0", reason="Triton >= 3.0.0 required to run this test"
)
hqq = pytest.importorskip("hqq", reason="hqq required to run this test")
hqq_quantize = pytest.importorskip(
    "hqq.core.quantize", reason="hqq required to run this test"
)
HQQLinear = hqq_quantize.HQQLinear
BaseQuantizeConfig = hqq_quantize.BaseQuantizeConfig

import itertools

import torch

from torchao.prototype.hqq import pack_2xint4, triton_mixed_mm

torch.manual_seed(0)
# N, K = shape
Q_SHAPES = [[4096, 4096]]
KV_SHAPES = [[4096, 4096], [1024, 4096]]
GROUP_SIZES = [64, 128]
AXES = [1]
DTYPES = [torch.bfloat16]

TRANSPOSED = [False, True]
TRITON_KERNEL_TYPE = ["compute_bound"]
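# Each test config is a tuple:
# (q_shape, kv_shape, group_size, axis, dtype, transposed, kernel_type)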
TEST_CONFIGS = list(
    itertools.product(
        Q_SHAPES, KV_SHAPES, GROUP_SIZES, AXES, DTYPES, TRANSPOSED, TRITON_KERNEL_TYPE
    )
)


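# HQQ weight-only quantization settings: 4-bit, grouped along axis=1,
# unpacked uint8 storage (bitpack=False); group_size is filled in per test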
BASE_QUANT_CONFIG = {
    "optimize": True,
    "view_as_float": False,
    "nbits": 4,
    "bitpack": False,
    "axis": 1,
}


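# Render list-valued params (e.g. shapes) as "NxK" strings in pytest ids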
def _arg_to_id(arg):
    if isinstance(arg, list):
        return "x".join([str(x) for x in arg])
    return str(arg)


def quantize_helper(
    weight_shape, quant_config, dtype, device="cuda", quant_dtype=torch.uint8
):
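    # Quantize a random (N, K) linear weight with HQQ and return the unpacked
    # quantized weight W_q (N x K) plus per-group scale/zero reshaped to
    # (N, K // group_size)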
    N, K = weight_shape
    linear = torch.nn.Linear(K, N, bias=False, dtype=dtype, device=device)

    hqq_linear = HQQLinear(linear, quant_config, compute_dtype=dtype, del_orig=False)
    W_q, meta = hqq_linear.W_q, hqq_linear.meta
    W_q = W_q.to(dtype=quant_dtype)
    W_q = (
        W_q.reshape(meta["shape"])
        if not quant_config["weight_quant_params"]["bitpack"]
        else W_q
    )

    scale, zero = meta["scale"], meta["zero"]
    scale = scale.reshape(N, -1)
    zero = zero.reshape(N, -1)

    return W_q, scale, zero


def fuse_qkv(W_qs, scales, zeros):
    """
    Args:
        W_qs (list[torch.Tensor]): length-3 list of tensors with shapes Nq x K, Nk x K, Nv x K, where Nk == Nv
        scales (list[torch.Tensor]): each is N x (K // group_size), with the same N as the corresponding W_q
        zeros (list[torch.Tensor]): same shapes as scales

    Returns:
        qkv (torch.Tensor): (N_qkv x K) where N_qkv = Nq + Nk + Nv
        scales (torch.Tensor): (N_qkv x (K // group_size))
        zeros (torch.Tensor): (N_qkv x (K // group_size))
    """
    qkv = torch.cat(W_qs, dim=0)  # Fuse along N
    fused_scales = torch.cat(scales, dim=0)
    fused_zeros = torch.cat(zeros, dim=0)
    return qkv, fused_scales, fused_zeros


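# Reference path: run the same triton kernel on a single (q, k, or v) projection
# so its output can be compared against the fused qkv matmul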
def ref_proj(x, packed_w, scale, zero, group_size, kernel_type, transposed=False):
    return triton_mixed_mm(
        x,
        packed_w,
        scale.T,
        zero.T,
        transposed=transposed,
        group_size=group_size,
        fp8_fast_accum=False,
        kernel_type=kernel_type,
    )


@pytest.mark.parametrize(
    "q_shape, kv_shape, group_size, axis, dtype, transposed, kernel_type",
    TEST_CONFIGS,
    ids=_arg_to_id,
)
def test_mixed_mm(
    q_shape,
    kv_shape,
    group_size,
    axis,
    dtype,
    transposed,
    kernel_type,
    seqlen=16,
    device="cuda",
    quant_dtype=torch.uint8,
):
    """
    Note that we test with dtype float32 in the transposed case, since the fused and non-fused ops are not exactly equivalent there.
    More specifically, when running the transposed matmul:
    - fused: we reduce along the fused N within a single kernel
    - non-fused: we launch 3 individual kernels, reduce along N within each kernel for q, k, v, then sum the three results post hoc to simulate the fused op
    This gives rise to numeric differences, given how accumulation is handled in the triton MAC loop.
    Using higher precision mitigates these issues for the purposes of this test.
    """

    # Override dtype per the note in the docstring above
    if transposed:
        dtype = torch.float32

    qcfg = {
        **BASE_QUANT_CONFIG,
        **dict(group_size=group_size, axis=axis),
    }

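    # HQQ quantization config: leave scale/zero unquantized and on-device;
    # the per-test weight quant params (group_size, axis, nbits, ...) come from qcfg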
    quant_config = BaseQuantizeConfig(
        quant_zero=False, quant_scale=False, offload_meta=False, view_as_float=False
    )
    quant_config.update({"weight_quant_params": qcfg})

    # Quantize q, k, v individually
    W_qs, packed_ws, scales, zeros = [], [], [], []
    for shape in [q_shape, kv_shape, kv_shape]:
        W_q, scale, zero = quantize_helper(
            shape, quant_config, dtype, device, quant_dtype
        )
        W_qs.append(W_q)
        packed_ws.append(pack_2xint4(W_q.T))
        scales.append(scale)
        zeros.append(zero)

    # Fuse q, k, v, scales, zeros
    qkv_fused, scales_fused, zeros_fused = fuse_qkv(W_qs, scales, zeros)
    qkv_fused_packed = pack_2xint4(qkv_fused.T)
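    # pack_2xint4 packs two 4-bit values per uint8 along dim 0, so the packed
    # weight's first dim is K // 2 (checked by the shape assertions below)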

    Ks = [shape[1] for shape in [q_shape, kv_shape]]

    K = Ks[0]

    # Check shapes
    assert all(k == K for k in Ks)
    assert qkv_fused_packed.shape[0] * 2 == qkv_fused.shape[1] == Ks[0]

    if transposed:
        Ns = [q_shape[0], kv_shape[0], kv_shape[0]]
        xs = [torch.randn(seqlen, n, dtype=dtype, device=device) for n in Ns]
        x_fused = torch.cat(xs, dim=1)
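        # Transposed case: each projection maps (seqlen, N_i) -> (seqlen, K), so
        # the fused output should equal the sum of the per-projection outputs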
        q_ref, k_ref, v_ref = [
            ref_proj(x, p, s, z, group_size, kernel_type, transposed=True)
            for x, p, s, z in zip(xs, packed_ws, scales, zeros)
        ]
        tt_fused = triton_mixed_mm(
            x_fused,
            qkv_fused_packed,
            scales_fused.T,
            zeros_fused.T,
            transposed=True,
            group_size=group_size,
            fp8_fast_accum=False,
            kernel_type=kernel_type,
        )
        tt_ref = q_ref + k_ref + v_ref
        assert torch.allclose(tt_ref, tt_fused, atol=1e-4)
    else:
        x = torch.randn(seqlen, K, dtype=dtype, device=device)

        q_ref, k_ref, v_ref = [
            ref_proj(x, p, s, z, group_size, kernel_type)
            for p, s, z in zip(packed_ws, scales, zeros)
        ]

        tt_fused = triton_mixed_mm(
            x,
            qkv_fused_packed,
            scales_fused.T,
            zeros_fused.T,
            transposed=False,
            group_size=group_size,
            fp8_fast_accum=False,
            kernel_type=kernel_type,
        )
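        # Split the fused (seqlen, N_qkv) output back into its q/k/v column
        # blocks and compare each against the corresponding non-fused reference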
        qN, kN, vN = q_shape[0], kv_shape[0], kv_shape[0]
        q_fused, k_fused, v_fused = tt_fused.split([qN, kN, vN], dim=1)

        for ref, fused in zip([q_ref, k_ref, v_ref], [q_fused, k_fused, v_fused]):
            assert torch.allclose(ref, fused)