diff --git a/benchmarks/benchmark_fp6.py b/benchmarks/benchmark_fp6.py
index e9f9d21398..0709035efa 100644
--- a/benchmarks/benchmark_fp6.py
+++ b/benchmarks/benchmark_fp6.py
@@ -1,23 +1,26 @@
 import torch
 import pandas as pd
-import torch.nn.functional as F
-from torchao.dtypes import to_affine_quantized_floatx
-from torchao.dtypes.floatx import FloatxTensorCoreAQTLayout, FloatxTensorCoreLayoutType
+import torchao
+from torchao.dtypes.floatx import from_scaled_tc_floatx, to_scaled_tc_floatx
 from torchao.utils import benchmark_torch_function_in_microseconds
 from tqdm import tqdm


 def benchmark(m: int, k: int, n: int):
-    float_data = torch.randn(n, k, dtype=torch.half, device="cuda")
-    fp6_weight = to_affine_quantized_floatx(float_data, FloatxTensorCoreLayoutType(3, 2))
-    fp16_weight = fp6_weight.dequantize(torch.half)
+    ebits = 3
+    mbits = 2

-    fp16_act = torch.randn(m, k, dtype=torch.half, device="cuda")
-    fp6_output = F.linear(fp16_act, fp6_weight)
-    fp16_output = F.linear(fp16_act, fp16_weight)
+    fp32_weight = torch.randn(n, k, device="cuda")
+    fp6_weight, scale = to_scaled_tc_floatx(fp32_weight, ebits, mbits)
+    fp16_act = torch.randn(m, k, dtype=torch.half, device="cuda") + 0.5

-    fp6_time = benchmark_torch_function_in_microseconds(F.linear, fp16_act, fp6_weight)
-    fp16_time = benchmark_torch_function_in_microseconds(F.linear, fp16_act, fp16_weight)
+    fp6_output = torchao.ops.quant_llm_linear(ebits, mbits, fp16_act, fp6_weight, scale, splitK=1)
+
+    fp16_weight = from_scaled_tc_floatx(fp6_weight, ebits, mbits, scale).half()
+    fp16_output = torch.matmul(fp16_act, fp16_weight.T)
+
+    fp6_time = benchmark_torch_function_in_microseconds(torchao.ops.quant_llm_linear, ebits, mbits, fp16_act, fp6_weight, scale, splitK=1)
+    fp16_time = benchmark_torch_function_in_microseconds(torch.matmul, fp16_act, fp16_weight.T)

     # follow https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/tests/python/kernel_test.py
     # doesn't seem to be the right way to check for correctness
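The trailing comments flag that the correctness check borrowed from fp6_llm's kernel_test.py may not be the right approach. A minimal sketch of one plausible alternative, comparing the FP6 kernel output against the dequantized FP16 reference via mean relative error; the shapes and the tolerance below are illustrative assumptions, not values from the source:

```python
import torch
import torchao
from torchao.dtypes.floatx import from_scaled_tc_floatx, to_scaled_tc_floatx

ebits, mbits = 3, 2
m, k, n = 1, 4096, 4096  # hypothetical shapes for illustration

# Same setup as the benchmark: quantize an FP32 weight to FP6 with a
# per-channel scale, and build a half-precision activation.
fp32_weight = torch.randn(n, k, device="cuda")
fp6_weight, scale = to_scaled_tc_floatx(fp32_weight, ebits, mbits)
fp16_act = torch.randn(m, k, dtype=torch.half, device="cuda") + 0.5

# FP6 kernel output vs. the FP16 matmul against the dequantized weight.
fp6_output = torchao.ops.quant_llm_linear(ebits, mbits, fp16_act, fp6_weight, scale, splitK=1)
fp16_weight = from_scaled_tc_floatx(fp6_weight, ebits, mbits, scale).half()
fp16_output = torch.matmul(fp16_act, fp16_weight.T)

# Mean relative error between kernel output and the FP16 reference;
# the 1e-2 threshold is an assumed ballpark, not a documented bound.
rel_err = (fp6_output - fp16_output).abs().mean() / fp16_output.abs().mean()
assert rel_err < 1e-2, f"relative error too high: {rel_err:.4f}"
```

Since both paths consume the same quantized weight, this isolates kernel error from quantization error, which an element-wise comparison against the original FP32 weight would conflate.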