Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fused HQQ Quantization Gemm #153

Merged
merged 19 commits into from
Apr 25, 2024
Merged
134 changes: 134 additions & 0 deletions benchmarks/benchmark_hqq.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
import torch
from termcolor import colored
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

will need to get rid of this dependency for merge, I'm fine with adding adding colors though so something like this should work

RED = "\033[31m"
GREEN = "\033[32m"
YELLOW = "\033[33m"
BLUE = "\033[34m"
MAGENTA = "\033[35m"
CYAN = "\033[36m"
WHITE = "\033[37m"
RESET = "\033[0m"  # Resets the color to default.

name = "Alice"
print(f"{GREEN}Hello, {name}!{RESET}")


import pandas as pd
from hqq.core.quantize import HQQLinear, BaseQuantizeConfig
from torchao.prototype.hqq.hqq_tinygemm_linear import HQQLinearTorchWeightOnlyInt4
from torchao.prototype.hqq import triton_mixed_mm, pack_2xint4

from triton.testing import do_bench
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could you print your pip list here? I'm having a lot of trouble finding which triton version you used

I tried the nightlies shown on the openai repo and I also tried make triton from inside pytorch repo and keep getting errors like

ImportError: cannot import name 'get_cuda_stream' from 'triton.runtime.jit' (/home/marksaroufim/.conda/envs/hqq/lib/python3.10/site-packages/triton/runtime/jit.py)

To make testing easier assume we'll be using https://github.com/pytorch/pytorch/blob/main/Makefile#L35



BASE_QUANT_CONFIG = {
"optimize": True,
"view_as_float": False,
"nbits": 4,
"bitpack": False,
"axis": 1,
}


def bench_custom_kernel(x, W_q, scales, zeros, group_size, kernel_type="max_autotune", fp8_fast_accum=False):
packed_w = pack_2xint4(W_q.T)

def fn():
_ = triton_mixed_mm(
x,
packed_w,
scales.T,
zeros.T,
group_size=group_size,
fp8_fast_accum=fp8_fast_accum,
kernel_type=kernel_type,
)

t = do_bench(fn)
return t


def bench_hqq(x, hqq_linear: HQQLinear):
def fn():
_ = hqq_linear.forward(x)

t = do_bench(fn)
return t


def run_benchmark(shape, group_size, dtype, axis=1, quant_dtype=torch.uint8):
qcfg = {
**BASE_QUANT_CONFIG,
**dict(group_size=group_size, axis=axis),
}
M, N, K = shape

x = torch.randn(M, K, dtype=dtype, device="cuda")
linear = torch.nn.Linear(K, N, bias=False, dtype=dtype, device="cuda")

quant_config = BaseQuantizeConfig(
quant_zero=False, quant_scale=False, offload_meta=False, view_as_float=False
)
quant_config.update({"weight_quant_params": qcfg})

hqq_linear = HQQLinear(linear, quant_config, compute_dtype=dtype, del_orig=False)

# Reference
ref_time = bench_hqq(x, hqq_linear)

# Custom kernel
W_q, meta = hqq_linear.W_q, hqq_linear.meta
scales, zeros = meta["scale"], meta["zero"]

W_q = (
W_q.reshape(meta["shape"])
if quant_config["weight_quant_params"]["bitpack"] == False
else W_q
)
W_q = W_q.to(dtype=quant_dtype)
scales = scales.reshape(N, -1)
zeros = zeros.reshape(N, -1)
tt_time = bench_custom_kernel(x, W_q, scales, zeros, group_size)

if dtype == torch.bfloat16:
_ = quant_config["weight_quant_params"].pop("bitpack")
hqq_int4mm = HQQLinearTorchWeightOnlyInt4(
linear, quant_config, compute_dtype=dtype, del_orig=False
)
int4_time = bench_hqq(x, hqq_int4mm)

print(colored(f"{shape=} {group_size=} {dtype=}:", attrs=["bold"]))

print(
colored(f"Ref: {ref_time:.4f}", "blue"),
colored(f"Triton: {tt_time:.4f}", "green"),
colored(f"Torch int4mm: {int4_time:.4f}", "yellow")
if dtype == torch.bfloat16
else "",
)
print()
return ref_time, tt_time, int4_time if dtype == torch.bfloat16 else None


SHAPES = [
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@cpuhrsch I guess these shapes are fine for now but are there some specific shapes we're more interested in tracking on an ongoing basis if so I wish we could just make them part of our benchmark or test utilities

[16, 4096, 4096],
[32, 4096, 4096],
[128, 4096, 4096],
[256, 4096, 4096],
[512, 4096, 4096],
[1024, 4096, 4096],
]

DTYPES = [torch.bfloat16] # , torch.float16]
GROUP_SIZES = [128]

print(torch.cuda.get_device_properties(0))

HEADERS = [
"M",
"N",
"K",
"group_size",
"dtype",
"ref",
"triton",
"tinygemm",
]
data = []
for shape in SHAPES:
for group_size in GROUP_SIZES:
for dtype in DTYPES:
timings = run_benchmark(shape, group_size, dtype)
data.append((*shape, group_size, dtype, *timings))


df = pd.DataFrame(data, columns=HEADERS)
df.to_csv("benchmark_triton.csv", index=False)
Copy link
Member

@msaroufim msaroufim Apr 22, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we will lose this csv on CI unless its saved to some github artifact so unless this file is huge let's just print it for now

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed

101 changes: 101 additions & 0 deletions test/hqq/test_triton_mm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
import itertools
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

both test and benchmark will require skips if triton is less than 3.0 (which is fine because nightlies now ship with 3.0.0) and if hqq is not installed

For hqq I'm fine if we add it as a dev dependency for now


import torch
from termcolor import colored

from hqq.core.quantize import HQQLinear, BaseQuantizeConfig
from hqq.kernels.custom_quant.triton import triton_mixed_mm, pack_2xint4
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can't seem to find where this is defined on stable or nightly releases

  File "/home/marksaroufim/ao/test/hqq/test_triton_mm.py", line 7, in <module>
    from hqq.kernels.custom_quant.triton import triton_mixed_mm, pack_2xint4
ModuleNotFoundError: No module named 'hqq.kernels.custom_quant'

from torchao.prototype.hqq import triton_mixed_mm, pack_2xint4
from torchao.prototype.hqq.hqq_tinygemm_linear import HQQLinearTorchWeightOnlyInt4


#TODO: refactor to pytest

#Test configs
SHAPES = [
# [16, 128],
[16, 128, 128],
[16, 4096, 4096],
# [1024, 4096],
# [4096, 4096],
# [4096, 11008],
]

DTYPES = [torch.bfloat16, torch.float16]
GROUP_SIZES = [64, 128]
AXES = [1] #Only axis = 1 supported
TRITON_KERNEL_TYPE = ["compute_bound"] #["max_autotune", "compute_bound"]
TEST_CONFIGS = list(itertools.product(SHAPES, GROUP_SIZES, AXES, DTYPES, TRITON_KERNEL_TYPE))

BASE_QUANT_CONFIG = {
"optimize": True,
"view_as_float": False,
"nbits": 4,
# "quant_dtype": torch.uint8,
"bitpack": False,
"axis": 1,
}


def check(expected, actual, cfg_str, max_diff=1e-3):
passed = torch.allclose(expected, actual, atol=max_diff, rtol=max_diff)
max_err = (expected - actual).abs().max()
if not passed:
print(colored(f"{cfg_str}: Failed! Max error: {max_err}", "red", attrs=["bold"]))
else:
print(colored(f"{cfg_str}: Passed! Max error: {max_err}", "green", attrs=["bold"]))

def test_mixed_mm(shape, group_size, axis, dtype, kernel_type, quant_dtype=torch.uint8):
# print(f"Test: {shape}, {group_size}, {axis}, {dtype}")
qcfg = {
**BASE_QUANT_CONFIG,
**dict(group_size=group_size, axis=axis),
}
M, N, K = shape

x = torch.randn(M, K, dtype=dtype, device="cuda")
linear = torch.nn.Linear(K, N, bias=False, dtype=dtype, device="cuda")

quant_config = BaseQuantizeConfig(
quant_zero=False, quant_scale=False, offload_meta=False, view_as_float=False
)
quant_config.update({"weight_quant_params": qcfg})
hqq_linear = HQQLinear(linear, quant_config, compute_dtype=dtype, del_orig=False)
W_q, meta = hqq_linear.W_q, hqq_linear.meta
W_q = (
W_q.reshape(meta["shape"])
if quant_config["weight_quant_params"]["bitpack"] == False
else W_q
)
scales, zeros = meta["scale"], meta["zero"]

#Reference
hqq_out = hqq_linear.forward(x)

##Triton
W_q = W_q.to(dtype=quant_dtype)
packed_w = pack_2xint4(W_q.T)
scales = scales.reshape(N, -1)
zeros = zeros.reshape(N, -1)
tt_out = triton_mixed_mm(
x, packed_w, scales.T, zeros.T, group_size=group_size, fp8_fast_accum=False, kernel_type=kernel_type
)

cfg_str = f"Test config {shape} {group_size} {dtype}"
# err = (hqq_out - tt_out).abs().max()
check(hqq_out, tt_out, cfg_str + " triton", max_diff=1e-2 if dtype == torch.bfloat16 else 1e-3)

if dtype == torch.bfloat16:
_ = quant_config["weight_quant_params"].pop("bitpack")
hqq_int4mm = HQQLinearTorchWeightOnlyInt4(
linear, quant_config, compute_dtype=dtype, del_orig=False
)
hqq_int4_out = hqq_int4mm.forward(x)
err = (hqq_int4_out - hqq_out).abs().max()
check(hqq_out, hqq_int4_out, cfg_str + " torch_tinygemm", max_diff=1e-2)

print()


for test in TEST_CONFIGS:
test_mixed_mm(*test)
43 changes: 43 additions & 0 deletions torchao/prototype/hqq/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
## Fused `int4 / fp16` Quant Matmul

Fused gemm for asymmetric quantized weights. Tested and benchmarked for `HQQ` but could theoretically be used for any asymmetric quantization scheme.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sounds like one of the 2 asymetric should be a symetric?


The kernel packs `u4 / s4` weights and fuses dequantization with the matmul.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

n00b q: whu can't we generically do this with torch.compile @HDCharles

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It does work with torch.compile and there's a good speed-up (up to 4x compared to Pytorch), but a dequantize() CUDA kernel + torch.matmul is a bit faster.
I think the bitpacking should be done in such a way that torch.compile can fully optimize it.


- tested for `float16 / bfloat16` activations, scales, and zeros
- autotuned for both compute-bound and io-bound configs
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: could use memory bandwidth bound terminology instead

- assumes operand B of the `gemm` is is the quantized type.
- requires quantization along `in-features`, i.e., the `K` dimension, or `axis=1`, of `torch.linear.weight`.

### Performance

Initial benchmarking demonstrates promising results, scaling well across io-bound and compute-bound workloads:

| | M | N | K | group_size | dtype | hqq_ref | triton | tinygemm |
| --- | ---- | ---- | ---- | ---------- | -------------- | ------- | ------ | -------- |
| 0 | 16 | 4096 | 4096 | 128 | torch.bfloat16 | 0.2675 | 0.0633 | 0.0382 |
| 1 | 32 | 4096 | 4096 | 128 | torch.bfloat16 | 0.2669 | 0.0704 | 0.0649 |
| 2 | 128 | 4096 | 4096 | 128 | torch.bfloat16 | 0.2689 | 0.0960 | 0.2523 |
| 3 | 256 | 4096 | 4096 | 128 | torch.bfloat16 | 0.3268 | 0.1355 | 0.5192 |
| 4 | 512 | 4096 | 4096 | 128 | torch.bfloat16 | 0.3628 | 0.2369 | 1.0892 |
| 5 | 1024 | 4096 | 4096 | 128 | torch.bfloat16 | 0.5133 | 0.4753 | 2.2016 |

- Times are in `ms`, see `benchmarks/benchmark_hqq.py`.
- `hqq_ref` is the base `HQQ_Linear` [module](https://github.com/mobiusml/hqq/blob/6d50eee4bcdd99cc10716f1297c5b2803d2b6da4/hqq/core/quantize.py#L349) that is unfused (dequantization followed by call to torch.matmul).
- `tinygemm` calls `torch.ops.aten._weight_int4pack_mm`. Implementation is a custom HQQLinear layer that wraps the preprocessing necessary for this kernel, adapted from a benchmark script posted by @mobicham from `CUDA-mode` Discord discussions.

GPU details:

```
_CudaDeviceProperties(name='NVIDIA RTX A6000', major=8, minor=6, total_memory=48676MB, multi_processor_count=84)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

once we figure out the installation issues I'll check to see if results repro on an H100

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

apologies meant pip freeze

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Collaborator Author

@jeromeku jeromeku Apr 24, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@msaroufim

When you run on H100, can you run once with DISABLE_MMA_V3=1? It toggles Hopper specific specializations in triton. Curious to see how performance changes.

```

### NOTE

This implementation requires **`triton >= 3.0.0`**.

- Running tests / benchmarks requires installation of `hqq`:

```
pip install hqq
```
1 change: 1 addition & 0 deletions torchao/prototype/hqq/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .mixed_mm import triton_mixed_mm, pack_2xint4
Loading
Loading