Commit d9a94c8

organized benchmarks
vayuda committed Jun 10, 2024
1 parent 9ef4c6c commit d9a94c8
Showing 1 changed file with 101 additions and 137 deletions.
238 changes: 101 additions & 137 deletions benchmarks/benchmark_bitpacking.py
@@ -1,12 +1,11 @@
# from torchao.quantization.quant_primitives import dynamically_quantize_per_channel
from torchao.prototype.common.bitpacking import pack, unpack
from torchao.dtypes.uint4 import unpack_uint4, pack_uint4
from math import log
# from torchao.utils import benchmark_utils
import torch

from torchao.prototype.common.bitpacking import pack, unpack
from torchao.dtypes.uint4 import unpack_uint4, pack_uint4


def benchmark(setup, function, num_runs):
def benchmark(function, num_runs, setup =None):
args = setup()
torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True)
@@ -20,7 +19,8 @@ def benchmark(setup, function, num_runs):
torch.cuda.synchronize()
return start_event.elapsed_time(end_event) / num_runs

def test_existing():

def test_vs_existing():
def new_():
fake_tensor = torch.randint(0, 2**8, (1, 1024,1024), dtype=torch.uint8).cuda()
packed = pack(fake_tensor, 4, dim=1)
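
For orientation, a minimal usage sketch of the reworked benchmark() helper above, assuming a CUDA device and assuming the helper expands the tuple returned by setup() into the timed function's arguments (that part of the loop body sits outside the hunks shown here); setup, roundtrip, and the run count are illustrative, not part of this commit:

import torch
from torchao.prototype.common.bitpacking import pack, unpack

def setup():
    # hypothetical helper: build the input once, outside the timed region
    return (torch.randint(0, 2**8, (1, 1024, 1024), dtype=torch.uint8).cuda(),)

def roundtrip(t):
    # pack uint8 values into 4-bit slots along dim=1, then unpack them
    unpack(pack(t, 4, dim=1), 4, dim=1)

print(f"pack/unpack round trip: {benchmark(roundtrip, 100, setup=setup)} ms")
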
@@ -37,29 +37,28 @@ def old_():
print(f"old: {benchmark(old_, 1000)} ms")


def load4x(scale=1024):
def test_iso_bitpack():
def load4x(scale=1024):
fake_tensor = torch.randint(0, 2**8, (1, 4*scale,scale), dtype=torch.uint8).cuda()

def load2x(scale=1024):
fake_tensor = torch.randint(0, 2**8, (1, 2*scale,scale), dtype=torch.uint8).cuda()
def load2x(scale=1024):
fake_tensor = torch.randint(0, 2**8, (1, 2*scale,scale), dtype=torch.uint8).cuda()

def loadx(scale=1024):
fake_tensor = torch.randint(0, 2**8, (1, scale,scale), dtype=torch.uint8).cuda()

def unpack8to2(scale=1024):
fake_tensor = torch.randint(0, 2**8, (1, scale,scale), dtype=torch.uint8).cuda()
unpacked_tensor = unpack_c(fake_tensor, 2, dim=1)

def unpack8to4(scale=1024):
fake_tensor = torch.randint(0, 2**8, (1, scale,scale), dtype=torch.uint8).cuda()
unpacked_tensor = unpack_c(fake_tensor, 4, dim=1)
def loadx(scale=1024):
fake_tensor = torch.randint(0, 2**8, (1, scale,scale), dtype=torch.uint8).cuda()

def t8to4wmm(scale=1024):
fake_tensor = torch.randint(0, 2**8, (8, 1024,1024), dtype=torch.uint8).cuda()
unpacked_tensor = unpack_c(fake_tensor, 4, dim=1)

def test_iso_bitpack():
def unpack8to2(scale=1024):
fake_tensor = torch.randint(0, 2**8, (1, scale,scale), dtype=torch.uint8).cuda()
unpacked_tensor = unpack_c(fake_tensor, 2, dim=1)

def unpack8to4(scale=1024):
fake_tensor = torch.randint(0, 2**8, (1, scale,scale), dtype=torch.uint8).cuda()
unpacked_tensor = unpack_c(fake_tensor, 4, dim=1)

def t8to4wmm(scale=1024):
fake_tensor = torch.randint(0, 2**8, (8, 1024,1024), dtype=torch.uint8).cuda()
unpacked_tensor = unpack_c(fake_tensor, 4, dim=1)

torch._dynamo.config.specialize_int = True
# _unpack_c = torch.compile(_unpack, fullgraph=True)
unpack_c = torch.compile(unpack, fullgraph=True)
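
Below is a hedged sketch of the measurement pattern test_iso_bitpack() is built around: compile unpack once with fullgraph=True, warm it up, then time a uint8 -> 4-bit unpack with CUDA events. The explicit timing loop and run count here are illustrative, not the benchmark() helper above:

import torch
from torchao.prototype.common.bitpacking import unpack

torch._dynamo.config.specialize_int = True
unpack_c = torch.compile(unpack, fullgraph=True)

scale = 1024
fake_tensor = torch.randint(0, 2**8, (1, scale, scale), dtype=torch.uint8).cuda()
unpack_c(fake_tensor, 4, dim=1)  # warm-up so compilation is not timed

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
for _ in range(100):
    unpack_c(fake_tensor, 4, dim=1)
end.record()
torch.cuda.synchronize()
print(f"compiled 8-bit -> 4-bit unpack: {start.elapsed_time(end) / 100} ms")
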
@@ -98,122 +97,86 @@ def test_iso_bitpack():
# plt.legend()
# plt.savefig("benchmark_bitpacking.png")

import hqq
import hqq.core.quantize as hqq_quantize
HQQLinear = hqq_quantize.HQQLinear
BaseQuantizeConfig = hqq_quantize.BaseQuantizeConfig

import itertools
from torchao.prototype.hqq import pack_2xint4, triton_mixed_mm


BASE_QUANT_CONFIG = {
"optimize": True,
"view_as_float": False,
"nbits": 4,
"bitpack": False,
"axis": 1,
}


def check(expected, actual, msg="", max_diff=1e-3, verbose=False):
passed = torch.allclose(expected, actual, atol=max_diff, rtol=max_diff)
if verbose:
max_err = (expected - actual).abs().max()
if not passed:
print_msg = f"{msg}:\nFailed! Max error: {max_err}"
try:
from termcolor import colored
except ImportError:
print(print_msg)
else:
print(colored(print_msg, "red", attrs=["bold"]))

else:
print_msg = f"{msg}:\nPassed! Max error: {max_err}"
try:
from termcolor import colored
except ImportError:
print(print_msg)
else:
print(colored(print_msg, "green", attrs=["bold"]))

return passed


def mixed_mm(
shape, group_size, axis, dtype, transposed, kernel_type, quant_dtype=torch.uint8, pack_fn = True
):
qcfg = {
**BASE_QUANT_CONFIG,
**dict(group_size=group_size, axis=axis),
def test_vs_hqqpack():
#requires hqq to be installed
import hqq
import hqq.core.quantize as hqq_quantize
HQQLinear = hqq_quantize.HQQLinear
BaseQuantizeConfig = hqq_quantize.BaseQuantizeConfig
from torchao.prototype.hqq import pack_2xint4, triton_mixed_mm

BASE_QUANT_CONFIG = {
"optimize": True,
"view_as_float": False,
"nbits": 4,
"bitpack": False,
"axis": 1,
}
M, N, K = shape

linear = torch.nn.Linear(K, N, bias=False, dtype=dtype, device="cuda")

quant_config = BaseQuantizeConfig(
quant_zero=False, quant_scale=False, offload_meta=False, view_as_float=False
)
quant_config.update({"weight_quant_params": qcfg})
hqq_linear = HQQLinear(linear, quant_config, compute_dtype=dtype, del_orig=False)
W_q, meta = hqq_linear.W_q, hqq_linear.meta
W_q = W_q.to(dtype=quant_dtype)
W_q = (
W_q.reshape(meta["shape"])
if quant_config["weight_quant_params"]["bitpack"] == False
else W_q
)
W_dq = hqq_linear.dequantize()

scales, zeros = meta["scale"], meta["zero"]
scales = scales.reshape(N, -1)
zeros = zeros.reshape(N, -1)
if pack_fn:
packed_w = pack(W_q.T,4,dim=0,order=False)
else:
packed_w = pack_2xint4(W_q.T)
# print(W_q.T[0:5,0:5], W_q.T.shape)
# print(packed_w[0:5,0:5], W_q.T.shape)
# print(packed_w2[0:5,0:5], W_q.T.shape)
if transposed:
x = torch.randn(M, N, dtype=dtype, device="cuda")
hqq_out = x @ W_dq

tt_out = triton_mixed_mm(
x,
packed_w,
scales.T,
zeros.T,
transposed=True,
group_size=group_size,
fp8_fast_accum=False,
kernel_type=kernel_type,

def mixed_mm(
shape, group_size, axis, dtype, transposed, kernel_type, quant_dtype=torch.uint8, pack_fn = True
):
qcfg = {
**BASE_QUANT_CONFIG,
**dict(group_size=group_size, axis=axis),
}
M, N, K = shape

linear = torch.nn.Linear(K, N, bias=False, dtype=dtype, device="cuda")

quant_config = BaseQuantizeConfig(
quant_zero=False, quant_scale=False, offload_meta=False, view_as_float=False
)

else:
x = torch.randn(M, K, dtype=dtype, device="cuda")
hqq_out = x @ W_dq.T

tt_out = triton_mixed_mm(
x,
packed_w,
scales.T,
zeros.T,
transposed=False,
group_size=group_size,
fp8_fast_accum=False,
kernel_type=kernel_type,
quant_config.update({"weight_quant_params": qcfg})
hqq_linear = HQQLinear(linear, quant_config, compute_dtype=dtype, del_orig=False)
W_q, meta = hqq_linear.W_q, hqq_linear.meta
W_q = W_q.to(dtype=quant_dtype)
W_q = (
W_q.reshape(meta["shape"])
if quant_config["weight_quant_params"]["bitpack"] == False
else W_q
)
# assert check(
# hqq_out,
# tt_out,
# max_diff=1e-2 if dtype == torch.bfloat16 else 1e-3,
# verbose=True,
# )
W_dq = hqq_linear.dequantize()

scales, zeros = meta["scale"], meta["zero"]
scales = scales.reshape(N, -1)
zeros = zeros.reshape(N, -1)
if pack_fn:
packed_w = pack(W_q.T,4,dim=0,order=False)
else:
packed_w = pack_2xint4(W_q.T)

if transposed:
x = torch.randn(M, N, dtype=dtype, device="cuda")
hqq_out = x @ W_dq

tt_out = triton_mixed_mm(
x,
packed_w,
scales.T,
zeros.T,
transposed=True,
group_size=group_size,
fp8_fast_accum=False,
kernel_type=kernel_type,
)

def test_vs_hqqpack():
else:
x = torch.randn(M, K, dtype=dtype, device="cuda")
hqq_out = x @ W_dq.T

tt_out = triton_mixed_mm(
x,
packed_w,
scales.T,
zeros.T,
transposed=False,
group_size=group_size,
fp8_fast_accum=False,
kernel_type=kernel_type,
)
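
For context, a hedged sketch of how the mixed_mm() helper above might be driven inside test_vs_hqqpack() for a single shape. It assumes hqq is installed, a CUDA device, and that "compute_bound" is an accepted kernel_type for triton_mixed_mm; the group size, dtype, and kernel types actually swept sit outside the hunks shown here:

# hypothetical single invocation; group_size, dtype, and kernel_type are assumed values
mixed_mm(
    shape=[16, 4096, 4096],
    group_size=128,
    axis=1,
    dtype=torch.float16,
    transposed=False,
    kernel_type="compute_bound",
    quant_dtype=torch.uint8,
    pack_fn=True,   # True -> torchao pack(); False -> hqq-style pack_2xint4()
)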

shapes = [
[16, 128, 128],
[16, 4096, 4096],
@@ -257,7 +220,8 @@ def test_vs_hqqpack():
torch.uint8,
pack_fn=False))
print("")



if __name__ == "__main__":
test_existing()
test_vs_existing()
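
As a correctness companion to the timings in test_vs_existing(), a hedged sanity-check sketch: values that already fit in 4 bits should survive the generic pack/unpack round trip used throughout this file. The check is illustrative; this commit only times the calls, it does not verify them:

import torch
from torchao.prototype.common.bitpacking import pack, unpack

t = torch.randint(0, 2**4, (1, 1024, 1024), dtype=torch.uint8).cuda()
roundtripped = unpack(pack(t, 4, dim=1), 4, dim=1)
print("lossless 4-bit round trip:", bool((roundtripped == t).all()))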
