Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEAT] Add custom CUDA tinygemm unpacker #415

Merged
merged 21 commits into from
Jul 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
227 changes: 224 additions & 3 deletions test/test_ops.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
import itertools

import torchao

import torch
from torch.testing._internal.common_utils import (
TestCase,
Expand All @@ -6,7 +10,7 @@
run_tests,
)
from torch.testing._internal.optests import opcheck
from torchao.utils import is_fbcode
from torchao.utils import is_fbcode, TORCH_VERSION_AFTER_2_5
from torchao.prototype.quant_llm import from_scaled_tc_fpx
import pytest

Expand All @@ -18,6 +22,14 @@
except RuntimeError:
pytest.skip("torchao.ops not available")

from torchao.quantization.utils import (
get_groupwise_affine_qparams,
groupwise_affine_dequantize_tensor_from_qparams,
groupwise_affine_quantize_tensor_from_qparams,
pack_tinygemm_scales_and_zeros,
unpack_tinygemm_scales_and_zeros,
)


class TestOps(TestCase):
def _create_fpx_inputs(self, ebits: int, mbits: int, BS: int, OC: int, IC: int, device):
Expand Down Expand Up @@ -61,9 +73,218 @@ def test_quant_llm_linear_correctness(self, ebits, mbits, BS, OC, IC, splitK):
relative_error = error / gt
assert relative_error < 1e-3


instantiate_parametrized_tests(TestOps)


## Tests for `tensor_core_layout`
kTileSizeN = 8
kTileSizeK = 16

SHAPES = [
(4096, 4096),
# Llama 2 GEMM shapes
(4096, 11008),
(11008, 4096),
# Llama 3 GEMM shapes
(4096, 14336),
(14336, 4096),
]
INNERKTILES = [2, 4, 8]
QGROUP_SIZES = [32, 64, 128, 256]
TEST_CONFIGS_UNPACK = list(itertools.product(SHAPES, INNERKTILES))
TEST_CONFIGS_DEQUANT = list(itertools.product(SHAPES, INNERKTILES, QGROUP_SIZES))

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize("shape, inner_k_tiles", TEST_CONFIGS_UNPACK, ids=str)
def test_unpack_tensor_core_tiled_layout_correctness(shape, inner_k_tiles):
N, K = shape
assert K % (inner_k_tiles * kTileSizeK) == 0 and N % kTileSizeN == 0

t = torch.randint(0, 16, dtype=torch.int, size=shape, device="cuda")
packed_w = torch.ops.aten._convert_weight_to_int4pack(t, inner_k_tiles)
unpacked = torchao.ops.unpack_tensor_core_tiled_layout(packed_w, inner_k_tiles)
assert torch.equal(t, unpacked)

# TODO: Fix "test_aot_dispatch_dynamic" test failure
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize("shape, inner_k_tiles", TEST_CONFIGS_UNPACK , ids=str)
def test_unpack_tensor_core_tiled_layout_op(shape, inner_k_tiles):
test_utils = [
"test_schema",
"test_autograd_registration",
"test_faketensor",
]

# TODO: Figure out why test fails unless torch >= 2.5
if TORCH_VERSION_AFTER_2_5:
test_utils.append("test_aot_dispatch_dynamic")

t = torch.randint(0, 16, dtype=torch.int, size=shape, device="cuda")
packed_w = torch.ops.aten._convert_weight_to_int4pack(t, inner_k_tiles)

opcheck(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

which pytorch version are you using? it seems this opcheck is moved to torch.library.opcheck: https://github.com/pytorch/pytorch/pull/124496/files

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

torch 2.5.0.dev20240624+cu121

torch.ops.torchao.unpack_tensor_core_tiled_layout,
(packed_w, inner_k_tiles),
test_utils=test_utils,
)

def dequant_ref(q, scales, zeros, group_size, nbits=4, dtype=torch.bfloat16):
n, k = q.shape
assert q.dtype == torch.int

n_groups = k // group_size
assert scales.shape[0] == n and scales.shape[1] == n_groups
assert scales.shape == zeros.shape

midpoint = 2 ** (nbits - 1)

#Convert fron u4 -> s4 and upcast to bfloat16
q = q.sub(midpoint).to(dtype)

# Dequantize
q = q.reshape(-1, group_size)
dq = q * scales.reshape(-1, 1) + zeros.reshape(-1, 1)

return dq.reshape(n, k)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize("shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str)
def test_dequantize_tensor_core_tiled_layout_correctness_quant_dequant(shape, inner_k_tiles, group_size):
n, k = shape
dtype = torch.bfloat16

device = "cuda"

t = torch.randn(n, k, dtype=dtype, device=device)
scales, zeros = get_groupwise_affine_qparams(t, n_bit=4, groupsize=group_size, dtype=dtype)

# Quantize
q = groupwise_affine_quantize_tensor_from_qparams(
t, scales, zeros, n_bit=4, groupsize=group_size
)

# Pack to tensor core layout
packed = torch.ops.aten._convert_weight_to_int4pack(q, inner_k_tiles)
scales_and_zeros = pack_tinygemm_scales_and_zeros(scales, zeros)
q_groups = k // group_size
assert scales_and_zeros.shape == torch.Size([q_groups, n, 2])

# Dequantize 'ao' ref
dq_ao = groupwise_affine_dequantize_tensor_from_qparams(
q, scales, zeros, n_bit=4, groupsize=group_size
)

# Dequantize by passing in an identity matrix as the activation
a_eye = torch.eye(k, device=device, dtype=dtype)
dq_id = torch.ops.aten._weight_int4pack_mm(
a_eye,
packed,
group_size,
scales_and_zeros,
).t()

# Actual operation to test
dq_op = torchao.ops.dequantize_tensor_core_tiled_layout(packed, scales_and_zeros, group_size, inner_k_tiles)

# Compare results
diff_ao_id = (dq_id - dq_ao).abs().max()
diff_op_id = (dq_op - dq_id).abs().max()
diff_op_ao = (dq_op - dq_ao).abs().max()

# There are slight numerical differences when dequantizing with an identity matrix when compared to `groupwise_affine_dequantize`
# Since the `dequantize_tensor_core_layout` kernel relies on the same underlying bit twiddling tricks for fast
# conversion from u4 -> s4 -> bf16, the identity matrix dequant hack and `dequantize_tensor_core_layout` are
# expected to give same results, while both will have similar numerical differences to `groupwise_affine_dequantize`.

# Test that the `dequant` kernel gives same results as identity matrix-based dequant
assert diff_op_id == 0

# Test that the `dequant` kernel gives same numerical diffs as the `groupwise_affine_dequantize` when compared against the identity matrix
assert diff_op_ao == diff_ao_id
jerryzh168 marked this conversation as resolved.
Show resolved Hide resolved

assert diff_op_ao < 1e-1

# This test differs from one above in that it uses `unpack_tensor_core_tiled_layout` to unpack then dequantize
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize("shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str)
def test_dequantize_tensor_core_tiled_layout_correctness_unpack_and_dequant(shape, inner_k_tiles, group_size):
n, k = shape
dtype = torch.bfloat16
device = "cuda"

# Quantize and pack
t = torch.randn(n, k, dtype=dtype, device=device)
scales, zeros = get_groupwise_affine_qparams(t, n_bit=4, groupsize=group_size, dtype=dtype)
q = groupwise_affine_quantize_tensor_from_qparams(
t, scales, zeros, n_bit=4, groupsize=group_size
)

packed = torch.ops.aten._convert_weight_to_int4pack(q, inner_k_tiles)
scales_and_zeros = pack_tinygemm_scales_and_zeros(scales, zeros)

# Unpack and dequantize
unpacked = torchao.ops.unpack_tensor_core_tiled_layout(packed, inner_k_tiles)
dq_ao = groupwise_affine_dequantize_tensor_from_qparams(
unpacked, scales, zeros, n_bit=4, groupsize=group_size
)

# Dequantize by passing in an identity matrix as the activation
a_eye = torch.eye(k, device=device, dtype=dtype)
dq_id = torch.ops.aten._weight_int4pack_mm(
a_eye,
packed,
group_size,
scales_and_zeros,
).t()

# Actual operation to test
dq_op = torchao.ops.dequantize_tensor_core_tiled_layout(packed, scales_and_zeros, group_size, inner_k_tiles)

# Compare results
diff_ao_id = (dq_id - dq_ao).abs().max()
diff_op_id = (dq_op - dq_id).abs().max()
diff_op_ao = (dq_op - dq_ao).abs().max()

# There are slight numerical differences when dequantizing with an identity matrix when compared to `groupwise_affine_dequantize`
# Since the `dequantize_tensor_core_layout` kernel relies on the same underlying bit twiddling tricks for fast
# conversion from u4 -> s4 -> bf16, the identity matrix dequant hack and `dequantize_tensor_core_layout` are
# expected to give same results, while both will have similar numerical differences to `groupwise_affine_dequantize`.

# Test that the `dequant` kernel gives same results as identity matrix-based dequant
assert diff_op_id == 0

# Test that the `dequant` kernel gives same numerical diffs as the `groupwise_affine_dequantize` when compared against the identity matrix
assert diff_op_ao == diff_ao_id

assert diff_op_ao < 1e-1

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize("shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str)
def test_dequantize_tensor_core_tiled_layout_op(shape, inner_k_tiles, group_size):
n, k = shape
device = "cuda"

q = torch.randint(0, 16, shape, dtype=torch.int, device=device)
packed_w = torch._convert_weight_to_int4pack(q, inner_k_tiles)
q_groups = k // group_size
scales = torch.randn(n, q_groups, dtype=torch.bfloat16, device=device)
zeros = torch.randn_like(scales)
scales_and_zeros = pack_tinygemm_scales_and_zeros(scales, zeros)

test_utils = [
"test_schema",
"test_autograd_registration",
"test_faketensor",
]
# TODO: Figure out why test fails unless torch >= 2.5
if TORCH_VERSION_AFTER_2_5:
test_utils.append("test_aot_dispatch_dynamic")
opcheck(
torch.ops.torchao.dequantize_tensor_core_tiled_layout,
(packed_w, scales_and_zeros, group_size, inner_k_tiles),
test_utils=test_utils,
)

if __name__ == "__main__":
run_tests()
run_tests()
Loading
Loading