Add FP16Act-FP6Weight Linear (pytorch#223)
* add files from fp6_llm

* try to port weight packing first

* rename

* rename fp6 weight packing

* add fp16act_fp6weight_linear

* fix function def

* delete duplicate file

* move weight quant file

* rename

* add pytorch interface for fp6 weight dequant

* add fake_fp6 to fp6

* move weight_quant to csrc/cuda due to cuda_fp16.h dependency

* add fake_fp6_to_fp6 test

* add test for fp16act_fp6weight_linear

* add test for fp6_weight_dequant

* Fp6WeightOnlyQuantizedLinearWeight (not working yet)

* skip some tests, since the functions are not built w/o CUDA

* add the original test

* implement transpose and clone so that F.linear will work

* remove print

* remove dequantize

* add notes and some rename

* typo

* small cleanup

* improve tensor subclass and add test (which is failing for torch-compile)

* add note

* add note

* add qtorch as dev requirement

* update error message

* add __repr__ and fix transposed issue

* add fp6 perplexity test

* rename variables

* remove subclass

* add correctness test

* remove unwanted changes

* add apache 2.0 notice

* add benchmark script

* add note about FP6 kernel

* relax tolerance

---------

Co-authored-by: Mark Saroufim <[email protected]>
2 people authored and lancerts committed May 17, 2024
1 parent b226a7e commit 0cf842c
Showing 17 changed files with 1,882 additions and 2 deletions.
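A minimal end-to-end sketch of how the new ops compose (shapes and splitK mirror test_fp16act_fp6weight_linear below; assumes a CUDA build of torchao):

import torch
import torchao

OC, IC, BS, splitK = 256, 256, 2, 1  # shapes from test_fp16act_fp6weight_linear below

# 16 FP6 values = 96 bits = three int32 words, hence IC // 16 * 3 packed columns
fp6_weight = torch.randint(4294967295, (OC, IC // 16 * 3)).to(torch.int)
fp16_scale = torch.rand(OC).half() + 0.5
fp16_activation = torch.rand(BS, IC).half() + 0.5

fp6_weight_packed = torchao.ops.prepack_fp6_weight(fp6_weight)  # rearrange into the layout the kernel expects
out = torchao.ops.fp16act_fp6weight_linear(
    fp16_activation.cuda(), fp6_weight_packed.cuda(), fp16_scale.cuda(), splitK
)  # FP16 output of shape (BS, OC)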
82 changes: 82 additions & 0 deletions benchmarks/benchmark_fp6.py
@@ -0,0 +1,82 @@
import torch
import torchao
from torch.utils.benchmark import Timer
import pandas as pd
from tqdm import tqdm


def benchmark(m, k, n, splitK):
    # Randomly initialize each byte. The high end for randint() is set to the max value of uint32_t.
    fp6_weight = torch.randint(4294967295, (n, k // 16 * 3)).to(torch.int)
    fp16_scale = torch.rand(n).half() + 0.5
    fp16_activation = torch.rand(m, k).half() + 0.5

    fp6_weight_packed = torchao.ops.prepack_fp6_weight(fp6_weight)
    act_cuda = fp16_activation.cuda()
    weight_cuda = fp6_weight_packed.cuda()
    scale_cuda = fp16_scale.cuda()

    # need to do this since Timer cannot see torchao
    def fp6_linear(act_cuda, weight_cuda, scale_cuda, splitK):
        return torchao.ops.fp16act_fp6weight_linear(act_cuda, weight_cuda, scale_cuda, splitK)

    fp6_output = fp6_linear(act_cuda, weight_cuda, scale_cuda, splitK)

    fp6_measurement = Timer(
        stmt="fp6_linear(act_cuda, weight_cuda, scale_cuda, splitK)",
        globals=locals(),
    ).blocked_autorange()

    fp16_weight = torchao.ops.fp6_weight_dequant(fp6_weight, fp16_scale).cuda()
    fp16_output = act_cuda @ fp16_weight.T

    fp16_measurement = Timer(
        stmt="act_cuda @ fp16_weight.T",
        globals=locals(),
    ).blocked_autorange()

    # follows https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/tests/python/kernel_test.py,
    # though this aggregate check does not seem to be the right way to verify correctness
    correct = (fp6_output - fp16_output).abs().mean() / fp16_output.abs().mean() < 1e-3

    return {
        "m": m,
        "k": k,
        "n": n,
        "fp6_latency (ms)": fp6_measurement.median * 1000,
        "fp16_latency (ms)": fp16_measurement.median * 1000,
        "speedup (d/s)": fp16_measurement.median / fp6_measurement.median,
        "correct": correct,
    }


if __name__ == "__main__":
    # from https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/tests/python/run.sh
    k_vals = (8192, 8192, 8192, 28672)
    n_vals = (10240, 8192, 57344, 8192)

    results = []

    # splitK can be tuned based on m, k, n
    for m, splitK_vals in tqdm([
        (1, (5, 6, 7, 6)),
        (2, (5, 6, 7, 6)),
        (4, (5, 6, 7, 6)),
        (8, (5, 6, 7, 6)),
        # (16, (5, 6, 7, 6)),
        # (64, (5, 6, 7, 6)),
        # (128, (5, 3, 3, 3)),
        # (256, (4, 3, 2, 3)),
        # (512, (2, 5, 2, 4)),
        (1024, (1, 2, 1, 2)),
        (2048, (1, 1, 1, 1)),
        (4096, (1, 1, 1, 1)),
        # (8192, (1, 1, 1, 1)),
        # (16384, (1, 1, 1, 1)),
    ]):
        for n, k, splitK in zip(n_vals, k_vals, splitK_vals):
            results.append(benchmark(m, k, n, splitK))  # match the benchmark(m, k, n, splitK) parameter order

    df = pd.DataFrame(results)
    df.to_csv("fp6_benchmark_results.csv", index=False)
    print(df.to_markdown(index=False))
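A note on the fp6_weight shape above: each row packs k FP6 values, and every 16 six-bit values occupy 16 * 6 = 96 bits, i.e. exactly three 32-bit words, which is where k // 16 * 3 comes from. A quick standalone sanity check:

# every 16 FP6 values = 96 bits = 3 int32 words
for k in (2048, 4096, 8192, 28672):
    assert (k // 16 * 3) * 32 == k * 6  # packed bits match k six-bit values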
4 changes: 2 additions & 2 deletions setup.py
@@ -63,10 +63,10 @@ def get_extensions():

    this_dir = os.path.dirname(os.path.curdir)
    extensions_dir = os.path.join(this_dir, "torchao", "csrc")
-   sources = list(glob.glob(os.path.join(extensions_dir, "*.cpp")))
+   sources = list(glob.glob(os.path.join(extensions_dir, "**/*.cpp"), recursive=True))

    extensions_cuda_dir = os.path.join(extensions_dir, "cuda")
-   cuda_sources = list(glob.glob(os.path.join(extensions_cuda_dir, "*.cu")))
+   cuda_sources = list(glob.glob(os.path.join(extensions_cuda_dir, "**/*.cu"), recursive=True))

    if use_cuda:
        sources += cuda_sources
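The switch to ** with recursive=True is what lets the build pick up the new sources nested under torchao/csrc/cuda/fp6_llm/. A minimal illustration of the difference (paths are the ones used above):

import glob
import os

csrc = os.path.join("torchao", "csrc")
flat = glob.glob(os.path.join(csrc, "*.cpp"))                          # old: top-level files only
nested = glob.glob(os.path.join(csrc, "**", "*.cpp"), recursive=True)  # new: also subdirectories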
93 changes: 93 additions & 0 deletions test/test_ops.py
@@ -4,6 +4,7 @@
import torchao
from torchao.quantization.utils import TORCH_VERSION_AFTER_2_4
import unittest
from parameterized import parameterized


# torch.testing._internal.optests.generate_tests.OpCheckError: opcheck(op, ...):
@@ -42,6 +43,98 @@ def test_nms(self):
        test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"]
        opcheck(torch.ops.torchao.nms, (boxes, scores, iou), test_utils=test_utils)

    def _create_fp6_inputs(self, BS: int, OC: int, IC: int):
        # Randomly initialize each byte. The high end for randint() is set to the max value of uint32_t.
        fp6_weight = torch.randint(4294967295, (OC, IC // 16 * 3)).to(torch.int)
        fp16_scale = torch.rand(OC).half() + 0.5
        fp16_activation = torch.rand(BS, IC).half() + 0.5
        return fp6_weight, fp16_scale, fp16_activation

    def test_prepack_fp6_weight(self):
        OC = 256
        IC = 256
        fp6_weight, _, _ = self._create_fp6_inputs(0, OC, IC)

        # smoke test
        torchao.ops.prepack_fp6_weight(fp6_weight)

        # comprehensive testing
        test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"]
        opcheck(torch.ops.torchao.prepack_fp6_weight, (fp6_weight,), test_utils=test_utils)

    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
    def test_fp16_to_fp6(self):
        OC = 256
        IC = 256

        # this FP6 format uses 3 bits for the exponent and 2 bits for the mantissa,
        # and it has no nan/inf encodings
        fp6_absmax = 28.0  # 2 ** (0b111 - 0b011) * (1 + 0.5 + 0.25), where E=111, M=11
        fp6_absmin = 0.0625  # 2 ** (-0b010) * 0.25, where E=000, M=01 (subnormal number)
        fp16_weight = torch.randn((OC, IC), dtype=torch.float16)
        fp16_weight.clip_(-fp6_absmax, fp6_absmax)
        fp16_weight[fp16_weight.abs() < fp6_absmin] = 0

        # smoke test
        torchao.ops.fp16_to_fp6(fp16_weight)

        # comprehensive testing
        test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"]
        opcheck(torch.ops.torchao.fp16_to_fp6, (fp16_weight,), test_utils=test_utils)

    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
    def test_fp16act_fp6weight_linear(self):
        BS = 2
        OC = 256
        IC = 256
        splitK = 1
        fp6_weight, fp16_scale, fp16_activation = self._create_fp6_inputs(BS, OC, IC)

        fp6_weight_packed = torchao.ops.prepack_fp6_weight(fp6_weight)
        act_cuda = fp16_activation.cuda()
        weight_cuda = fp6_weight_packed.cuda()
        scale_cuda = fp16_scale.cuda()

        # smoke test
        torchao.ops.fp16act_fp6weight_linear(act_cuda, weight_cuda, scale_cuda, splitK)

        # comprehensive testing
        test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"]
        opcheck(torch.ops.torchao.fp16act_fp6weight_linear, (act_cuda, weight_cuda, scale_cuda, splitK), test_utils=test_utils)

    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
    def test_fp6_weight_dequant(self):
        OC = 256
        IC = 256
        fp6_weight, fp16_scale, _ = self._create_fp6_inputs(0, OC, IC)

        # smoke test
        torchao.ops.fp6_weight_dequant(fp6_weight, fp16_scale)

        # comprehensive testing
        test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"]
        opcheck(torch.ops.torchao.fp6_weight_dequant, (fp6_weight, fp16_scale), test_utils=test_utils)

    # adapted from https://github.com/usyd-fsalab/fp6_llm/blob/main/tests/python/kernel_test.py
    @parameterized.expand([(1, 2048, 4096, 5), (2, 8192, 8192, 6)])
    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
    def test_fp6_matmul_correctness(self, BS, OC, IC, splitK):
        fp6_weight, fp16_scale, fp16_activation = self._create_fp6_inputs(BS, OC, IC)

        fp6_weight_packed = torchao.ops.prepack_fp6_weight(fp6_weight)
        act_cuda = fp16_activation.cuda()
        weight_cuda = fp6_weight_packed.cuda()
        scale_cuda = fp16_scale.cuda()

        results_fp6 = torchao.ops.fp16act_fp6weight_linear(act_cuda, weight_cuda, scale_cuda, splitK)

        fp16_weight = torchao.ops.fp6_weight_dequant(fp6_weight, fp16_scale).cuda()
        results_fp16 = act_cuda @ fp16_weight.T

        error = (results_fp6 - results_fp16).abs()
        relative_error = error / results_fp16.abs()
        assert relative_error.mean() < 1e-2


if __name__ == "__main__":
    unittest.main()
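The fp6_absmax/fp6_absmin constants in test_fp16_to_fp6 follow from enumerating the E3M2 grid (3 exponent bits with bias 3, 2 mantissa bits, subnormals at E=000, no nan/inf); a small sketch that recomputes them:

# enumerate all positive FP6 (E3M2) magnitudes
values = []
for e in range(8):       # exponent field 0b000..0b111
    for m in range(4):   # mantissa field 0b00..0b11
        if e == 0:
            values.append(2.0 ** (1 - 3) * (m / 4))      # subnormal: no implicit leading 1
        else:
            values.append(2.0 ** (e - 3) * (1 + m / 4))  # normal
positive = [v for v in values if v > 0]
assert max(positive) == 28.0    # matches fp6_absmax
assert min(positive) == 0.0625  # matches fp6_absmin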
90 changes: 90 additions & 0 deletions torchao/csrc/cuda/fp6_llm/configs.h
@@ -0,0 +1,90 @@
// Copyright 2024 FP6-LLM authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// This file is copied from https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/include/configs.h

#ifndef CONFIGS_H
#define CONFIGS_H

//#define DEBUG_MODE
#define PIPELINE_LEVEL_GMEM 2
#define PIPELINE_LEVEL_SMEM 2 // only supports 2

/************************ Hardware Parameters ************************/
#define WARP_SIZE 32
#define REG_BIT_WIDTH 32
// mma: M=16 K=16 N=8
#define MMA_8 8
#define MMA_16 16
// for memory access
#define THREAD_OPT_ACCESS_BIT_WIDTH_128 128 // LDS.128, cp_async.128, ...
#define BIT_WIDTH_PER_HALF 16 // Half precision: FP16

/******************** Register Allocation For GEMM ********************/
#define REG_PER_THREAD_C_TENSOR_16_16 8 // 8 for FP32 Accumulation
/********************** Memory Padding Parameters **********************/
// Eliminating bank-conflict
#define PADDING_BYTES_16 16 // Padding 16 bytes each column
#define PADDING_SHARED_MEM_FOR_B_8 8 // Padding 8 half each column, during CopyFromGlobalToShared() for B
#define PADDING_SHARED_MEM_FOR_C_4 4 // Padding 4 float each column, during StoreToSharedMemoryFromRegister() for C
/************************* WARP Tiling part-1 *************************/
#define WARP_ROW_MMA_TENSORS 4
#define WARP_M (WARP_ROW_MMA_TENSORS * MMA_16) // 64
#define WARP_K_MMA_TENSORS 4
#define WARP_K (WARP_K_MMA_TENSORS * MMA_16) // 64
template<int BLOCK_ROW_WARPS_, int BLOCK_COL_WARPS_, int WARP_COL_MMA_TENSORS_>
struct TilingConfig {
    // Depends on the "n" dimension of the GEMM
    static constexpr int BLOCK_ROW_WARPS = BLOCK_ROW_WARPS_;
    static constexpr int BLOCK_COL_WARPS = BLOCK_COL_WARPS_;
    static constexpr int WARP_COL_MMA_TENSORS = WARP_COL_MMA_TENSORS_;
    /************************* WARP Tiling part-2 *************************/
    static constexpr int WARP_N = WARP_COL_MMA_TENSORS * MMA_8;
    /************************ Thread Block Tiling *************************/
    static constexpr int TILE_M = WARP_M * BLOCK_ROW_WARPS;
    static constexpr int TILE_N = MMA_8 * WARP_COL_MMA_TENSORS * BLOCK_COL_WARPS;
    static constexpr int TILE_K = WARP_K;
    /********************** #Threads per Thread Block *********************/
    static constexpr int BLOCK_WARPS = BLOCK_ROW_WARPS * BLOCK_COL_WARPS;
    static constexpr int BLOCK_THREADS = BLOCK_WARPS * WARP_SIZE;
    /******************************* Others *******************************/
    static constexpr int SMEM_SIZE_B_TILE = TILE_N * (TILE_K + PADDING_BYTES_16) * 2 * PIPELINE_LEVEL_GMEM; // sizeof(half)=2, doubleBuffer=2
    static constexpr int SMEM_SIZE_C_TILE = TILE_N * (TILE_M + PADDING_BYTES_16) * 4; // sizeof(float)=4
};

/************************ General Config for FP6-LLM **********************/
#define WEIGHT_FRAG1_BIT_WIDTH 2
#define WEIGHT_FRAG2_BIT_WIDTH 4
#define WEIGHT_BIT_WIDTH (WEIGHT_FRAG1_BIT_WIDTH+WEIGHT_FRAG2_BIT_WIDTH) // 6
//#define QUANT_GROUP_SIZE_DIVIDED_BY_64 4 // QuantGroupSize: 4*64 = 256
/*************************** 64*64 Weights of A WARP *************************/
#define WEIGHT_PER_UNIT (WARP_M*WARP_K) // 64*64
#define SMEM_SIZE_IN_BYTES_PER_WARP_A1 (WEIGHT_PER_UNIT*WEIGHT_FRAG1_BIT_WIDTH/8) // 1024 Bytes; double buffering not taken into consideration
#define SMEM_SIZE_IN_BYTES_PER_WARP_A2 (WEIGHT_PER_UNIT*WEIGHT_FRAG2_BIT_WIDTH/8) // 2048 Bytes; double buffering not taken into consideration
#define SMEM_SIZE_A1_TILE (SMEM_SIZE_IN_BYTES_PER_WARP_A1*4*PIPELINE_LEVEL_GMEM) // #WARP=4; triple buffer for 3-level pipeline for A = 12 KB; double buffer for 2-level pipeline for A = 8 KB
#define SMEM_SIZE_A2_TILE (SMEM_SIZE_IN_BYTES_PER_WARP_A2*4*PIPELINE_LEVEL_GMEM) // #WARP=4; triple buffer for 3-level pipeline for A = 24 KB; double buffer for 2-level pipeline for A = 16 KB
/******************** Global Memory Layout For QUANTIZED DATA ******************/
#define NUM_INT4_PER_UNIT_2BIT_FRAG (WEIGHT_PER_UNIT*WEIGHT_FRAG1_BIT_WIDTH/128) // 64
#define NUM_INT4_PER_UNIT_4BIT_FRAG (WEIGHT_PER_UNIT*WEIGHT_FRAG2_BIT_WIDTH/128) // 128
/******************** Register Allocation For QUANTIZED DATA ******************/
#define WEIGHT_PER_THREAD (WEIGHT_PER_UNIT/WARP_SIZE) // 128
#define REG_PER_THREAD_2BIT_FRAG (WEIGHT_PER_THREAD/REG_BIT_WIDTH*2) // 8
#define REG_PER_THREAD_4BIT_FRAG (WEIGHT_PER_THREAD/REG_BIT_WIDTH*4) // 16
/******************** Register Allocation For QUANT Scales ******************/
#define WARP_REG_QUANT_SCALE 4 // 8 rows per thread -> 8 FP16 scales -> 4 registers
#define WARP_REG_QUANT_SCALE_DISTRIBUTED 1 // T0-T3, T4-T7, ..., T28-T31 share the same scales, using shfl to get all the scales for each thread



#endif // CONFIGS_H
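The derived constants above can be sanity-checked outside CUDA; a Python sketch recomputing them for an illustrative TilingConfig<2, 2, 4> instantiation (the template arguments are assumed for the example, not taken from a particular kernel launch):

WARP_SIZE, MMA_8, MMA_16 = 32, 8, 16
PIPELINE_LEVEL_GMEM = 2
WARP_M = 4 * MMA_16  # WARP_ROW_MMA_TENSORS * MMA_16 = 64
WARP_K = 4 * MMA_16  # WARP_K_MMA_TENSORS * MMA_16 = 64

BLOCK_ROW_WARPS, BLOCK_COL_WARPS, WARP_COL_MMA_TENSORS = 2, 2, 4  # assumed example
TILE_M = WARP_M * BLOCK_ROW_WARPS                                 # 128
TILE_N = MMA_8 * WARP_COL_MMA_TENSORS * BLOCK_COL_WARPS           # 64
BLOCK_THREADS = BLOCK_ROW_WARPS * BLOCK_COL_WARPS * WARP_SIZE     # 128 threads per block

WEIGHT_PER_UNIT = WARP_M * WARP_K                 # 4096 weights per warp tile
A1 = WEIGHT_PER_UNIT * 2 // 8                     # 1024 bytes of 2-bit fragments per warp
A2 = WEIGHT_PER_UNIT * 4 // 8                     # 2048 bytes of 4-bit fragments per warp
assert A1 * 4 * PIPELINE_LEVEL_GMEM == 8 * 1024   # SMEM_SIZE_A1_TILE: 8 KB double-buffered
assert A2 * 4 * PIPELINE_LEVEL_GMEM == 16 * 1024  # SMEM_SIZE_A2_TILE: 16 KB double-buffered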