From 7734f7928587442ad42caaeca69aea28f3ae4012 Mon Sep 17 00:00:00 2001
From: Thien Tran
Date: Tue, 14 May 2024 23:52:48 +0800
Subject: [PATCH] Add FP16Act-FP6Weight Linear (#223)

* add files from fp6_llm
* try to port weight packing first
* rename
* rename fp6 weight packing
* add fp16act_fp6weight_linear
* fix function def
* delete duplicate file
* move weight quant file
* rename
* add pytorch interface for fp6 weight dequant
* add fake_fp6 to fp6
* move weight_quant to csrc/cuda due to cuda_fp16.h dependency
* add fake_fp6_to_fp6 test
* add test for fp16act_fp6weight_linear
* add test for fp6_weight_dequant
* Fp6WeightOnlyQuantizedLinearWeight (not working yet)
* skip some tests, since the functions are not built w/o CUDA
* add the original test
* implement transpose and clone so that F.linear will work
* remove print
* remove dequantize
* add notes and some rename
* typo
* small cleanup
* improve tensor subclass and add test (which is failing for torch-compile)
* add note
* add note
* add qtorch as dev requirement
* update error message
* add __repr__ and fix transposed issue
* add fp6 perplexity test
* rename variables
* remove subclass
* add correctness test
* remove unwanted changes
* add apache 2.0 notice
* add benchmark script
* add note about FP6 kernel
* relax tolerance

---------

Co-authored-by: Mark Saroufim
---
 benchmarks/benchmark_fp6.py                     |  82 +++++++
 setup.py                                        |   4 +-
 test/test_ops.py                                |  93 ++++++++
 torchao/csrc/cuda/fp6_llm/configs.h             |  90 +++++++
 torchao/csrc/cuda/fp6_llm/fp6_linear.cu         | 184 +++++++++++++++
 torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh     | 188 +++++++++++++++
 .../csrc/cuda/fp6_llm/kernel_reduction.cuh      |  63 +++++
 torchao/csrc/cuda/fp6_llm/ptx_cp.async.cuh      |  75 ++++++
 torchao/csrc/cuda/fp6_llm/ptx_mma.cuh           | 129 ++++++++++
 torchao/csrc/cuda/fp6_llm/utils_core.cuh        | 216 +++++++++++++++++
 torchao/csrc/cuda/fp6_llm/utils_gmem.cuh        |  91 ++++++++
 .../cuda/fp6_llm/utils_parallel_dequant.cuh     | 127 ++++++++++
 torchao/csrc/cuda/fp6_llm/weight_quant.cu       | 219 +++++++++++++++++
 torchao/csrc/fp6_llm/README.md                  |   7 +
 torchao/csrc/fp6_llm/fp6_llm.cpp                |  11 +
 torchao/csrc/fp6_llm/weight_prepacking.cpp      | 220 ++++++++++++++++++
 torchao/ops.py                                  |  85 +++++++
 17 files changed, 1882 insertions(+), 2 deletions(-)
 create mode 100644 benchmarks/benchmark_fp6.py
 create mode 100644 torchao/csrc/cuda/fp6_llm/configs.h
 create mode 100644 torchao/csrc/cuda/fp6_llm/fp6_linear.cu
 create mode 100644 torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh
 create mode 100644 torchao/csrc/cuda/fp6_llm/kernel_reduction.cuh
 create mode 100644 torchao/csrc/cuda/fp6_llm/ptx_cp.async.cuh
 create mode 100644 torchao/csrc/cuda/fp6_llm/ptx_mma.cuh
 create mode 100644 torchao/csrc/cuda/fp6_llm/utils_core.cuh
 create mode 100644 torchao/csrc/cuda/fp6_llm/utils_gmem.cuh
 create mode 100644 torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh
 create mode 100644 torchao/csrc/cuda/fp6_llm/weight_quant.cu
 create mode 100644 torchao/csrc/fp6_llm/README.md
 create mode 100644 torchao/csrc/fp6_llm/fp6_llm.cpp
 create mode 100644 torchao/csrc/fp6_llm/weight_prepacking.cpp

diff --git a/benchmarks/benchmark_fp6.py b/benchmarks/benchmark_fp6.py
new file mode 100644
index 0000000000..abe21d2f7d
--- /dev/null
+++ b/benchmarks/benchmark_fp6.py
@@ -0,0 +1,82 @@
+import torch
+import torchao
+from torch.utils.benchmark import Timer
+import pandas as pd
+from tqdm import tqdm
+
+
+def benchmark(m, k, n, splitK):
+    # Randomly initialize each byte. The highest value for randint() is set to the max value of uint32_t.
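+    # Note on the shape below: 16 FP6 values take 16 * 6 = 96 bits = 3 * 32 bits, which is why
+    # the packed weight holds (k // 16 * 3) int32 elements per output channel. The CUDA op's
+    # docstring describes the same layout as "3 INT32 words contains 16 FP6 weights".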
+ fp6_weight = torch.randint(4294967295, (n, k // 16 * 3)).to(torch.int) + fp16_scale = torch.rand(n).half() + 0.5 + fp16_activation = torch.rand(m, k).half() + 0.5 + + fp6_weight_packed = torchao.ops.prepack_fp6_weight(fp6_weight) + act_cuda = fp16_activation.cuda() + weight_cuda = fp6_weight_packed.cuda() + scale_cuda = fp16_scale.cuda() + + # need to do this since Timer cannot see torchao + def fp6_linear(act_cuda, weight_cuda, scale_cuda, splitK): + return torchao.ops.fp16act_fp6weight_linear(act_cuda, weight_cuda, scale_cuda, splitK) + + fp6_output = fp6_linear(act_cuda, weight_cuda, scale_cuda, splitK) + + fp6_measurement = Timer( + stmt="fp6_linear(act_cuda, weight_cuda, scale_cuda, splitK)", + globals=locals(), + ).blocked_autorange() + + fp16_weight = torchao.ops.fp6_weight_dequant(fp6_weight, fp16_scale).cuda() + fp16_output = act_cuda @ fp16_weight.T + + fp16_measurement = Timer( + stmt="act_cuda @ fp16_weight.T", + globals=locals(), + ).blocked_autorange() + + # follow https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/tests/python/kernel_test.py + # doesn't seem to be the right way to check for correctness + correct = (fp6_output - fp16_output).abs().mean() / fp16_output.abs().mean() < 1e-3 + + return { + "m": m, + "k": k, + "n": n, + "fp6_latency (ms)": fp6_measurement.median * 1000, + "fp16_latency (ms)": fp16_measurement.median * 1000, + "speedup (d/s)": fp16_measurement.median / fp6_measurement.median, + "correct": correct, + } + + +if __name__ == "__main__": + # from https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/tests/python/run.sh + k_vals = (8192, 8192, 8192, 28672) + n_vals = (10240, 8192, 57344, 8192) + + results = [] + + # splitK can be tuned based on m, k, n + for m, splitK_vals in tqdm([ + (1, (5, 6, 7, 6)), + (2, (5, 6, 7, 6)), + (4, (5, 6, 7, 6)), + (8, (5, 6, 7, 6)), + # (16, (5, 6, 7, 6)), + # (64, (5, 6, 7, 6)), + # (128, (5, 3, 3, 3)), + # (256, (4, 3, 2, 3)), + # (512, (2, 5, 2, 4)), + (1024, (1, 2, 1, 2)), + (2048, (1, 1, 1, 1)), + (4096, (1, 1, 1, 1)), + # (8192, (1, 1, 1, 1)), + # (16384, (1, 1, 1, 1)), + ]): + for n, k, splitK in zip(n_vals, k_vals, splitK_vals): + results.append(benchmark(m, n, k, splitK)) + + df = pd.DataFrame(results) + df.to_csv("fp6_benchmark_results.csv", index=False) + print(df.to_markdown(index=False)) diff --git a/setup.py b/setup.py index 7d4875cadb..5d1f32da2b 100644 --- a/setup.py +++ b/setup.py @@ -63,10 +63,10 @@ def get_extensions(): this_dir = os.path.dirname(os.path.curdir) extensions_dir = os.path.join(this_dir, "torchao", "csrc") - sources = list(glob.glob(os.path.join(extensions_dir, "*.cpp"))) + sources = list(glob.glob(os.path.join(extensions_dir, "**/*.cpp"), recursive=True)) extensions_cuda_dir = os.path.join(extensions_dir, "cuda") - cuda_sources = list(glob.glob(os.path.join(extensions_cuda_dir, "*.cu"))) + cuda_sources = list(glob.glob(os.path.join(extensions_cuda_dir, "**/*.cu"), recursive=True)) if use_cuda: sources += cuda_sources diff --git a/test/test_ops.py b/test/test_ops.py index a569f24799..e260e86f0f 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -4,6 +4,7 @@ import torchao from torchao.quantization.utils import TORCH_VERSION_AFTER_2_4 import unittest +from parameterized import parameterized # torch.testing._internal.optests.generate_tests.OpCheckError: opcheck(op, ...): @@ -42,6 +43,98 @@ def test_nms(self): test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"] 
opcheck(torch.ops.torchao.nms, (boxes, scores, iou), test_utils=test_utils) + def _create_fp6_inputs(self, BS: int, OC: int, IC: int): + # Randomly initialize each bytes. The highest value for randint() is set the the max value of uint32_t. + fp6_weight = torch.randint(4294967295, (OC, IC // 16 * 3)).to(torch.int) + fp16_scale = torch.rand(OC).half() + 0.5 + fp16_activation = torch.rand(BS, IC).half() + 0.5 + return fp6_weight, fp16_scale, fp16_activation + + def test_prepack_fp6_weight(self): + OC = 256 + IC = 256 + fp6_weight, _, _ = self._create_fp6_inputs(0, OC, IC) + + # smoke test + torchao.ops.prepack_fp6_weight(fp6_weight) + + # comprehensive testing + test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"] + opcheck(torch.ops.torchao.prepack_fp6_weight, (fp6_weight,), test_utils=test_utils) + + @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") + def test_fp16_to_fp6(self): + OC = 256 + IC = 256 + + # in this fp6, we use 3 bits for exponent and 2 bits for mantissa + # also, we don't have nan/inf + fp6_absmax = 28.0 # 2 ** (0b111 - 0b011) * (1 + 0.5 + 0.25), where E=111, M=11 + fp6_absmin = 0.0625 # 2 ** (-0b010) * 0.25, where E=000, M=01 (subnormal number) + fp16_weight = torch.randn((OC, IC), dtype=torch.float16) + fp16_weight.clip_(-fp6_absmax, fp6_absmax) + fp16_weight[fp16_weight.abs() < fp6_absmin] = 0 + + # smoke test + torchao.ops.fp16_to_fp6(fp16_weight) + + # comprehensive testing + test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"] + opcheck(torch.ops.torchao.fp16_to_fp6, (fp16_weight,), test_utils=test_utils) + + @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") + def test_fp16act_fp6weight_linear(self): + BS = 2 + OC = 256 + IC = 256 + splitK = 1 + fp6_weight, fp16_scale, fp16_activation = self._create_fp6_inputs(BS, OC, IC) + + fp6_weight_packed = torchao.ops.prepack_fp6_weight(fp6_weight) + act_cuda = fp16_activation.cuda() + weight_cuda = fp6_weight_packed.cuda() + scale_cuda = fp16_scale.cuda() + + # smoke test + torchao.ops.fp16act_fp6weight_linear(act_cuda, weight_cuda, scale_cuda, splitK) + + # comprehensive testing + test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"] + opcheck(torch.ops.torchao.fp16act_fp6weight_linear, (act_cuda, weight_cuda, scale_cuda, splitK), test_utils=test_utils) + + @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") + def test_fp6_weight_dequant(self): + OC = 256 + IC = 256 + fp6_weight, fp16_scale, _ = self._create_fp6_inputs(0, OC, IC) + + # smoke test + torchao.ops.fp6_weight_dequant(fp6_weight, fp16_scale) + + # comprehensive testing + test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"] + opcheck(torch.ops.torchao.fp6_weight_dequant, (fp6_weight, fp16_scale), test_utils=test_utils) + + # adapted from https://github.com/usyd-fsalab/fp6_llm/blob/main/tests/python/kernel_test.py + @parameterized.expand([(1, 2048, 4096, 5), (2, 8192, 8192, 6)]) + @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") + def test_fp6_matmul_correctness(self, BS, OC, IC, splitK): + fp6_weight, fp16_scale, fp16_activation = self._create_fp6_inputs(BS, OC, IC) + + fp6_weight_packed = torchao.ops.prepack_fp6_weight(fp6_weight) + act_cuda = fp16_activation.cuda() + weight_cuda = fp6_weight_packed.cuda() + scale_cuda = fp16_scale.cuda() + + results_fp6 = 
torchao.ops.fp16act_fp6weight_linear(act_cuda, weight_cuda, scale_cuda, splitK) + + fp16_weight = torchao.ops.fp6_weight_dequant(fp6_weight, fp16_scale).cuda() + results_fp16 = act_cuda @ fp16_weight.T + + error = (results_fp6 - results_fp16).abs() + relative_error = error / results_fp16.abs() + assert relative_error.mean() < 1e-2 + if __name__ == "__main__": unittest.main() diff --git a/torchao/csrc/cuda/fp6_llm/configs.h b/torchao/csrc/cuda/fp6_llm/configs.h new file mode 100644 index 0000000000..0a642fc805 --- /dev/null +++ b/torchao/csrc/cuda/fp6_llm/configs.h @@ -0,0 +1,90 @@ +// Copyright 2024 FP6-LLM authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// This file is copied from https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/include/configs.h + +#ifndef CONFIGS_H +#define CONFIGS_H + +//#define DEBUG_MODE +#define PIPELINE_LEVEL_GMEM 2 +#define PIPELINE_LEVEL_SMEM 2 // only support 2 + +/************************ Hardware Parameters ************************/ +#define WARP_SIZE 32 +#define REG_BIT_WIDTH 32 +// mma: M=16 K=16 N=8 +#define MMA_8 8 +#define MMA_16 16 +// for memory access +#define THREAD_OPT_ACCESS_BIT_WIDTH_128 128 // LDS.128, cp_async.128, ... +#define BIT_WIDTH_PER_HALF 16 // Half precision: FP16 + +/******************** Register Allocation For GEMM ********************/ +#define REG_PER_THREAD_C_TENSOR_16_16 8 // 8 for FP32 Accumulation +/********************** Memory Padding Parameters **********************/ +// Eliminating bank-conflict +#define PADDING_BYTES_16 16 // Padding 16 bytes each column +#define PADDING_SHARED_MEM_FOR_B_8 8 // Padding 8 half each column, during CopyFromGlobalToShared() for B +#define PADDING_SHARED_MEM_FOR_C_4 4 // Padding 4 float each column, during StoreToSharedMemoryFromRegister() for C +/************************* WARP Tiling part-1 *************************/ +#define WARP_ROW_MMA_TENSORS 4 +#define WARP_M (WARP_ROW_MMA_TENSORS * MMA_16) // 64 +#define WARP_K_MMA_TENSORS 4 +#define WARP_K (WARP_K_MMA_TENSORS * MMA_16) // 64 +template +struct TilingConfig { + // Depending on "n" dimension of the GEMM + static constexpr int BLOCK_ROW_WARPS = BLOCK_ROW_WARPS_; + static constexpr int BLOCK_COL_WARPS = BLOCK_COL_WARPS_; + static constexpr int WARP_COL_MMA_TENSORS = WARP_COL_MMA_TENSORS_; + /************************* WARP Tiling part-2 *************************/ + static constexpr int WARP_N = WARP_COL_MMA_TENSORS * MMA_8; + /*************************Thread Block Tiling *************************/ + static constexpr int TILE_M = WARP_M * BLOCK_ROW_WARPS; + static constexpr int TILE_N = MMA_8 * WARP_COL_MMA_TENSORS * BLOCK_COL_WARPS; + static constexpr int TILE_K = WARP_K; + /********************** #Thread per Thread Block **********************/ + static constexpr int BLOCK_WARPS = BLOCK_ROW_WARPS * BLOCK_COL_WARPS; + static constexpr int BLOCK_THREADS = BLOCK_WARPS * WARP_SIZE; + /******************************* Others *******************************/ + static constexpr int 
SMEM_SIZE_B_TILE = TILE_N * (TILE_K + PADDING_BYTES_16) * 2 * PIPELINE_LEVEL_GMEM; // sizeof(half)=2, doubleBuffer=2 + static constexpr int SMEM_SIZE_C_TILE = TILE_N * (TILE_M + PADDING_BYTES_16) * 4; // sizeof(float)=4 +}; + +/************************ General Config for FP6-LLM **********************/ +#define WEIGHT_FRAG1_BIT_WIDTH 2 +#define WEIGHT_FRAG2_BIT_WIDTH 4 +#define WEIGHT_BIT_WIDTH (WEIGHT_FRAG1_BIT_WIDTH+WEIGHT_FRAG2_BIT_WIDTH) // 6 +//#define QUANT_GROUP_SIZE_DIVIDED_BY_64 4 // QuantGroupSize: 4*64 = 256 +/*************************** 64*64 Weghts of A WARP *************************/ +#define WEIGHT_PER_UNIT (WARP_M*WARP_K) // 64*64 +#define SMEM_SIZE_IN_BYTES_PER_WARP_A1 (WEIGHT_PER_UNIT*WEIGHT_FRAG1_BIT_WIDTH/8) // 1024 Bytes #doubleBuffer not takedn into consideration +#define SMEM_SIZE_IN_BYTES_PER_WARP_A2 (WEIGHT_PER_UNIT*WEIGHT_FRAG2_BIT_WIDTH/8) // 2048 Bytes #doubleBuffer not takedn into consideration +#define SMEM_SIZE_A1_TILE (SMEM_SIZE_IN_BYTES_PER_WARP_A1*4*PIPELINE_LEVEL_GMEM) // #WARP=4, #Trible-Buffer for 3-level pipeline for A = 12 KB; double buffer for 2-level pipeline A= 8 KB. +#define SMEM_SIZE_A2_TILE (SMEM_SIZE_IN_BYTES_PER_WARP_A2*4*PIPELINE_LEVEL_GMEM) // #WARP=4, #Trible-Buffer for 3-level pipeline for A = 24 KB; double buffer for 2-level pipeline A= 16 KB. +/******************** Gloabl Memory Layout For QUANTIZED DATA ******************/ +#define NUM_INT4_PER_UNIT_2BIT_FRAG (WEIGHT_PER_UNIT*WEIGHT_FRAG1_BIT_WIDTH/128) // 64 +#define NUM_INT4_PER_UNIT_4BIT_FRAG (WEIGHT_PER_UNIT*WEIGHT_FRAG2_BIT_WIDTH/128) // 128 +/******************** Register Allocation For QUANTIZED DATA ******************/ +#define WEIGHT_PER_THREAD (WEIGHT_PER_UNIT/WARP_SIZE) // 128 +#define REG_PER_THREAD_2BIT_FRAG (WEIGHT_PER_THREAD/REG_BIT_WIDTH*2) // 8 +#define REG_PER_THREAD_4BIT_FRAG (WEIGHT_PER_THREAD/REG_BIT_WIDTH*4) // 16 +/******************** Register Allocation For QUANT Scales ******************/ +#define WARP_REG_QUANT_SCALE 4 // 8 rows per thread -> 8 FP16 scales -> 4 registers +#define WARP_REG_QUANT_SCALE_DISTRIBUTED 1 // T0-T3, T4-T7, ..., T28-T31 share the same scales, using shfl to get all the scales for each thread + + + +#endif // CONFIGS_H diff --git a/torchao/csrc/cuda/fp6_llm/fp6_linear.cu b/torchao/csrc/cuda/fp6_llm/fp6_linear.cu new file mode 100644 index 0000000000..51413a0874 --- /dev/null +++ b/torchao/csrc/cuda/fp6_llm/fp6_linear.cu @@ -0,0 +1,184 @@ +// Copyright 2024 FP6-LLM authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +// This file is adapted from https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/fp6_linear.cu + +#include "kernel_matmul.cuh" +#include "kernel_reduction.cuh" + +#include +#include + +template +static void Kernel_Ex(cudaStream_t stream, + const uint4 *Weight, + const half *Scales, + const half *B, + OutputDataType *C, + const size_t M_Global, + const size_t N_Global, + const size_t K_Global, + int Split_K) +{ + #ifdef DEBUG_MODE + printf("\n"); + printf("Launcher.cu->Kernel_Ex():\n"); + printf("M: %d, N: %d, K: %d, SplitK: %d\n", M_Global, N_Global, K_Global, Split_K); + printf("TILE_M: %d, TILE_K: %d, TILE_N: %d\n", TilingConfig::TILE_M, TilingConfig::TILE_K, TilingConfig::TILE_N); + #endif + static size_t SHMEM_SZ = max(TilingConfig::SMEM_SIZE_B_TILE+SMEM_SIZE_A1_TILE+SMEM_SIZE_A2_TILE, TilingConfig::SMEM_SIZE_C_TILE); + cudaFuncSetAttribute(QUANT_GEMM_Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, SHMEM_SZ); + size_t dimN = (N_Global-1) / TilingConfig::TILE_N + 1; + size_t dimM = M_Global * Split_K / TilingConfig::TILE_M; + dim3 GridDim(dimN, dimM, 1); + dim3 BlockDim(WARP_SIZE * TilingConfig::BLOCK_WARPS, 1, 1); + // + #ifdef DEBUG_MODE + printf("GridDim.x: %d, GridDim.y: %d, GridDim.z: %d, BlockDim.x: %d, BlockDim.y: %d, BlockDim.z: %d SHMEM_SZ: %d\n", + GridDim.x, GridDim.y, GridDim.z, BlockDim.x, BlockDim.y, BlockDim.z, SHMEM_SZ); + printf("\n"); + #endif + QUANT_GEMM_Kernel<<>> + (Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); +} + +/* + * + */ +cudaError_t fp6_linear_kernel(cudaStream_t stream, + const uint4 *Weight, + const half *Scales, + const half *B, + half *C, + const size_t M_Global, + const size_t N_Global, + const size_t K_Global, + float *Reduction_Workspace, // Reduction_Workspace_Size = Split_K * M_Global * N_Global * sizeof(fp32) + int Split_K) +{ + assert(M_Global % 256 == 0); + assert(K_Global % 64 == 0); + assert(N_Global>0); + + // Work around to support more N shapes: + size_t N_PowerOf2; + if(N_Global>0 && N_Global<=8) N_PowerOf2 = 8; + if(N_Global>8 && N_Global<=16) N_PowerOf2 = 16; + if(N_Global>16 && N_Global<=32) N_PowerOf2 = 32; + if(N_Global>32 && N_Global<=64) N_PowerOf2 = 64; + if(N_Global>64 && N_Global<=128) N_PowerOf2 = 128; + if(N_Global>128) N_PowerOf2 = ((N_Global-1)/128+1) * 128; + + if (Split_K == 1) { + switch (N_PowerOf2) { + case 8: Kernel_Ex, half>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break; + case 16: Kernel_Ex, half>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break; + case 32: Kernel_Ex, half>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break; + case 64: Kernel_Ex, half>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break; + case 128: Kernel_Ex, half>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break; + default: if (N_PowerOf2 % 128 != 0) { + printf("FP6LLM_API Error: Unsupported N dimension %d!\n", N_PowerOf2); + return cudaErrorUnknown; + } + Kernel_Ex, half>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break; + } + } + else { + switch (N_PowerOf2) { + case 8: Kernel_Ex, float>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break; + case 16: Kernel_Ex, float>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break; + case 32: Kernel_Ex, float>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); 
break; + case 64: Kernel_Ex, float>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break; + case 128: Kernel_Ex, float>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break; + default: if (N_PowerOf2 % 128 != 0) { + printf("FP6LLM_API Error: Unsupported N dimension %d!\n", N_PowerOf2); + return cudaErrorUnknown; + } + Kernel_Ex, float>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break; + } + // Reduction for SplitK + dim3 GridDim((M_Global * N_Global) / REDUCTION_ELEMENT_PER_THREADBLOCK, 1, 1); + dim3 BlockDim(WARP_SIZE, 1, 1); + SplitK_Reduction<<>>(C, Reduction_Workspace, M_Global, N_Global, Split_K); + } + return cudaGetLastError(); +} + + +#include +#include +#include + +namespace torchao { +/* +Computes FP6-FP16 GEMM (PyTorch interface). + +[Mathmatical Formula] +Standard definition of linear layer: Out = In * trans(W), where In, Out, and W are stored in row-major. +After Equivalent transformation : trans(Out) = W * trans(In). Note that we do not perform "transpose" during runtime, we instead interpret the In/Out as column-major matrices when calling our CUDA kernel. + +[Inputs] + _in_feats: tensor of shape [B, IC]; // half + _weights: int tensor of shape [OC, IC // 16 * 3]; // 3 INT32 words contains 16 FP6 weights. + _scales: tensor of shape [OC]; // half + splitK: spliting the MatMul problem along K dimension for higher GPU utilization, default 1. +[Outputs] + _out_feats: tensor of shape [B, OC]; // half +*/ +torch::Tensor fp6_linear_forward_cuda(torch::Tensor _in_feats, + torch::Tensor _weights, + torch::Tensor _scales, + int64_t splitK=1) +{ + int num_in_feats = _in_feats.size(0); + int num_in_channels = _in_feats.size(1); + int num_out_channels = _weights.size(0); + assert( num_in_channels%64 == 0 ); + assert( (num_in_channels/16*3) == _weights.size(1) ); // Making sure the K dimension is matched. + // + int M = num_out_channels; + int K = num_in_channels; + int N = num_in_feats; + // Input Tensors + auto weight = reinterpret_cast(_weights.data_ptr()); // weights is [OC, IC] but in FP6. + auto in_feats = reinterpret_cast(_in_feats.data_ptr()); + auto scales = reinterpret_cast(_scales.data_ptr()); + // Output Tensors + auto options = torch::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device()); + at::Tensor _out_feats = torch::empty({num_in_feats, num_out_channels}, options); + auto out_feats = reinterpret_cast(_out_feats.data_ptr()); + + options = torch::TensorOptions().dtype(torch::kFloat32).device(_in_feats.device()); + at::Tensor _workspace = torch::empty({splitK, num_in_feats, num_out_channels}, options); + auto Reduction_Workspace = reinterpret_cast(_workspace.data_ptr()); // Reduction_Workspace_Size = Split_K * M_Global * N_Global * sizeof(fp32) + + fp6_linear_kernel(0, // Using default stream here. 
+ weight, + scales, + in_feats, + out_feats, + M, + N, + K, + Reduction_Workspace, + splitK); + + return _out_feats; +} + +TORCH_LIBRARY_IMPL(torchao, CUDA, m) { + m.impl("torchao::fp16act_fp6weight_linear", &fp6_linear_forward_cuda); +} + +} // namespace torchao diff --git a/torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh b/torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh new file mode 100644 index 0000000000..de7775ddce --- /dev/null +++ b/torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh @@ -0,0 +1,188 @@ +// Copyright 2024 FP6-LLM authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// This file is copied from https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/include/kernel_matmul.cuh + +#include "configs.h" +#include "utils_gmem.cuh" +#include "utils_core.cuh" + +/* + * C = A*B + * A: row major with ahead-of-time layout transformation, FP6 + * B: col major, FP16 + * C: col major, FP16 + */ + template +__global__ void QUANT_GEMM_Kernel(const uint4* Weight, const half* Scales, + const half *B, + OutputDataType* C, + const size_t M_Global, const size_t N_Global, const size_t K_Global, + int Split_K) +{ + #ifdef DEBUG_MODE + assert(K_Global%TilingConfig::TILE_K==0); + assert(M_Global%TilingConfig::TILE_M==0); + assert( gridDim.y == Split_K * (M_Global/TilingConfig::TILE_M)); + #endif + // 2+4 weight split + const uint4* Weight1 = Weight; + const uint4* Weight2 = Weight1 + M_Global*K_Global*2/128; + // Dynamic shared memory for FP16 A tiles, 128 Bytes aligned + extern __shared__ __align__(128) half smem[]; + half (*smem_array)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] = reinterpret_cast ( smem + (SMEM_SIZE_A1_TILE+SMEM_SIZE_A2_TILE)/2 ); // Dynamic shared memory for FP16 B tiles + __shared__ half QuantScales[64*TilingConfig::BLOCK_WARPS]; // static shared memory for quantization scales, 64 row per warp * 4 warps = 512 Bytes + // Thread Block Mapping, considering SplitK + const size_t BatchID = blockIdx.y / (M_Global/TilingConfig::TILE_M); + const size_t x = blockIdx.x; // Output Block ID: (BlockID_Row = y; BlockID_Col = x ) + const size_t y = blockIdx.y % (M_Global/TilingConfig::TILE_M); // Output Block ID: (BlockID_Row = y; BlockID_Col = x ) + const size_t Tile_Start_M = y * TilingConfig::TILE_M; + const size_t Tile_Start_N = x * TilingConfig::TILE_N; + const size_t NumColumnToCopy = (N_Global-Tile_Start_N) < TilingConfig::TILE_N ? 
(N_Global-Tile_Start_N) : TilingConfig::TILE_N; + const size_t NumBlock_K = K_Global/TilingConfig::TILE_K; + const size_t AverageNumBlock_K = NumBlock_K/Split_K; + const size_t ExtraNumBlock_K = NumBlock_K - AverageNumBlock_K * Split_K; + size_t NumIter = AverageNumBlock_K; + if(BatchID(smem); + uint32_t* AFrag_4BIT_SPTR = AFrag_2BIT_SPTR+SMEM_SIZE_IN_BYTES_PER_WARP_A1/4*TilingConfig::BLOCK_WARPS*PIPELINE_LEVEL_GMEM; // 8 buffers including double buffers, 12 for trible buffers + // StartSPTR for each WARP + AFrag_2BIT_SPTR += warpId * SMEM_SIZE_IN_BYTES_PER_WARP_A1/4; + AFrag_4BIT_SPTR += warpId * SMEM_SIZE_IN_BYTES_PER_WARP_A2/4; + // Pre-fetch of A tile + for(int i=0; i(AFrag_2BIT_SPTR+i*SMEM_SIZE_IN_BYTES_PER_WARP_A1/4*4, WARP_StartGPTR_A1); + CopyFromGlobalToShared_A(AFrag_4BIT_SPTR+i*SMEM_SIZE_IN_BYTES_PER_WARP_A2/4*4, WARP_StartGPTR_A2); + WARP_StartGPTR_A1 += SMEM_SIZE_IN_BYTES_PER_WARP_A1/16; + WARP_StartGPTR_A2 += SMEM_SIZE_IN_BYTES_PER_WARP_A2/16; + } + // Global Memory Address for Matrix A (QuantScale) ///////////////////////////////////////////////////////////////////// + const half* TB_StartGPTR_A_Scale = Scales + (y*TilingConfig::BLOCK_ROW_WARPS) * 64; + const half* WARP_StartGPTR_A_Scales = TB_StartGPTR_A_Scale + WARP_i * 64; + CopyFromGlobalToShared_Scales(QuantScales+WARP_i*64, WARP_StartGPTR_A_Scales); + // Copying B tile from Global to Shared, considering SplitK ///////////////////////////////////////////////////////////// + const half *BTile_GPTR = B + Tile_Start_N * K_Global + StartBlockID_K * TilingConfig::TILE_K; + for(int i=0; i (smem_array+i*TilingConfig::TILE_N, BTile_GPTR, K_Global, NumColumnToCopy); + BTile_GPTR += TilingConfig::TILE_K; + } + // Register Allocation for A,B, and C, Initilazed to Zeros ///////////////////////////////////////////////////////////////////// + constexpr int NumRegSets_a = WARP_ROW_MMA_TENSORS; // 1 set = 4 registers, containing a 16*16 MMA block + constexpr int NumRegSets_b = (TilingConfig::WARP_COL_MMA_TENSORS==1) ? 1 : TilingConfig::WARP_COL_MMA_TENSORS/2; // 1 set = 4 registers, containing a 16*16 MMA block +#ifdef PIPELINE_LEVEL_SMEM + uint32_t a [NumRegSets_a * PIPELINE_LEVEL_SMEM][4]; // double/Trible buffer is used // Registers to store decompressed FP6 + uint32_t b [NumRegSets_b * PIPELINE_LEVEL_SMEM][4]; // double/Triple buffer is used // Register to store FP16 B matrix (a slice) +#endif + float c[NumRegSets_a * NumRegSets_b][REG_PER_THREAD_C_TENSOR_16_16]; + for(int i=0; i(a, b, AFrag_2BIT_SPTR, AFrag_4BIT_SPTR, smem_array, Scales_RPTR); +#endif + // The outer loop. 
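+    // (High-level sketch of the loop that follows: each iteration issues cp.async copies that
+    // prefetch the next K-tile of the packed FP6 A fragments and of the FP16 B tile into the
+    // "write_SPTR_*" shared-memory buffers, while the tensor-core MMAs consume the already
+    // resident "read_SPTR_*" buffers, dequantizing FP6 to FP16 in registers slice by slice;
+    // cp_async_wait_group() followed by __syncthreads() then makes the prefetched tile visible
+    // before the buffers rotate modulo PIPELINE_LEVEL_GMEM on the next iteration.)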
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + #pragma unroll(1) + for (size_t tile_id_k = 0; tile_id_k < NumIter; tile_id_k++) + { + // Trible-Buffer for A Tile + uint32_t* __restrict__ read_SPTR_Frag1 = AFrag_2BIT_SPTR + ((tile_id_k+0) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_IN_BYTES_PER_WARP_A1/4*4; // 1024 (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16 + uint32_t* __restrict__ read_SPTR_Frag2 = AFrag_4BIT_SPTR + ((tile_id_k+0) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_IN_BYTES_PER_WARP_A2/4*4; // 2048 (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16 +#ifdef PIPELINE_LEVEL_SMEM + uint32_t* __restrict__ read2_SPTR_Frag1 = AFrag_2BIT_SPTR + ((tile_id_k+1) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_IN_BYTES_PER_WARP_A1/4*4; + uint32_t* __restrict__ read2_SPTR_Frag2 = AFrag_4BIT_SPTR + ((tile_id_k+1) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_IN_BYTES_PER_WARP_A2/4*4; +#endif + uint32_t* __restrict__ write_SPTR_Frag1 = AFrag_2BIT_SPTR + ((tile_id_k+(PIPELINE_LEVEL_GMEM-1)) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_IN_BYTES_PER_WARP_A1/4*4; // 1024 (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16 + uint32_t* __restrict__ write_SPTR_Frag2 = AFrag_4BIT_SPTR + ((tile_id_k+(PIPELINE_LEVEL_GMEM-1)) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_IN_BYTES_PER_WARP_A2/4*4; // 2048 (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16 + // Trible-Buffer for B Tile + half __restrict__ (*read_SPTR )[WARP_K+PADDING_SHARED_MEM_FOR_B_8] = smem_array + ((tile_id_k+0) % PIPELINE_LEVEL_GMEM) * TilingConfig::TILE_N; +#ifdef PIPELINE_LEVEL_SMEM + half __restrict__ (*read2_SPTR )[WARP_K+PADDING_SHARED_MEM_FOR_B_8] = smem_array + ((tile_id_k+1) % PIPELINE_LEVEL_GMEM) * TilingConfig::TILE_N; +#endif + half __restrict__ (*write_SPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] = smem_array + ((tile_id_k+(PIPELINE_LEVEL_GMEM-1)) % PIPELINE_LEVEL_GMEM) * TilingConfig::TILE_N; + // + bool GlobalCopy = (tile_id_k+PIPELINE_LEVEL_GMEM-1) < NumIter; + // Copying A tile from Global to Register, Bypassing L1, using double-buffer + CopyFromGlobalToShared_A(write_SPTR_Frag1, WARP_StartGPTR_A1, GlobalCopy); + CopyFromGlobalToShared_A(write_SPTR_Frag2, WARP_StartGPTR_A2, GlobalCopy); + // copying B tile from GlobalMemory to SharedMemory + CopyFromGlobalToShared (write_SPTR, BTile_GPTR, K_Global, NumColumnToCopy, GlobalCopy); + cp_async_group_commit(); + #ifdef PIPELINE_LEVEL_SMEM + core_mma_slice(c, a, b, read_SPTR_Frag1, read_SPTR_Frag2, read_SPTR, Scales_RPTR, 1); // read_SPTR_Frag1, read_SPTR_Frag2 are different for each WARP; read_SPTR is shared among WARPs + core_mma_slice(c, a, b, read_SPTR_Frag1, read_SPTR_Frag2, read_SPTR, Scales_RPTR, 2); + core_mma_slice(c, a, b, read_SPTR_Frag1, read_SPTR_Frag2, read_SPTR, Scales_RPTR, 3); + // Barriers and Synchronizations + cp_async_wait_group(); + __syncthreads(); + core_mma_slice(c, a, b, read2_SPTR_Frag1, read2_SPTR_Frag2, read2_SPTR, Scales_RPTR, 0); + // Updating global PTRs + WARP_StartGPTR_A1 += SMEM_SIZE_IN_BYTES_PER_WARP_A1/16; // 4KB/16=256 (1)/16: int4*+1 = char*+16 + WARP_StartGPTR_A2 += SMEM_SIZE_IN_BYTES_PER_WARP_A2/16; // 8KB/16=512 (1)/16: int4*+1 = char*+16 + BTile_GPTR += TilingConfig::TILE_K; + #else + PipelinedCoreLoop(c, read_SPTR, read_SPTR_Frag1, read_SPTR_Frag2, Scales_RPTR); // read_SPTR_Frag1, read_SPTR_Frag2 are different for each WARP; read_SPTR is shared among WARPs + // Updating global PTRs + WARP_StartGPTR_A1 += SMEM_SIZE_IN_BYTES_PER_WARP_A1/16; // 4KB/16=256 (1)/16: int4*+1 = char*+16 + WARP_StartGPTR_A2 += SMEM_SIZE_IN_BYTES_PER_WARP_A2/16; // 8KB/16=512 
(1)/16: int4*+1 = char*+16 + BTile_GPTR += TilingConfig::TILE_K; + // Barriers and Synchronizations + cp_async_wait_group(); + __syncthreads(); + #endif + } + ///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + ///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + // Store the C fragments to shared memory. + float (*smem_CFrag) [TilingConfig::TILE_M+PADDING_SHARED_MEM_FOR_C_4] = + reinterpret_cast (smem); + StoreToSharedMemoryFromRegister(smem_CFrag, c); + __syncthreads(); + // Now that shared memory contains all the D tiles, stream them to global memory. + OutputDataType* BlockGlobalPTR = C + BatchID*(M_Global*N_Global) + Tile_Start_M + Tile_Start_N*M_Global; + for(size_t i=warpId; i::value) BlockGlobalPTR[j+i*M_Global] = __float2half_rn(smem_CFrag[i][j]); + else BlockGlobalPTR[j+i*M_Global] = smem_CFrag[i][j]; + } +} diff --git a/torchao/csrc/cuda/fp6_llm/kernel_reduction.cuh b/torchao/csrc/cuda/fp6_llm/kernel_reduction.cuh new file mode 100644 index 0000000000..c0e7c1918a --- /dev/null +++ b/torchao/csrc/cuda/fp6_llm/kernel_reduction.cuh @@ -0,0 +1,63 @@ +// Copyright 2024 FP6-LLM authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// This file is copied from https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/include/kernel_reduction.cuh + +/*************************************************************************** + * Copyright 2023 The FLash-LLM Authors. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ ***************************************************************************/ +// Used for the reduction of result matrix if Split-K is used +// Reduction_Workspace: (Split_K, M_Global, N_Global), column major +// C: (M_Global, N_Global), column major +// Each thread deals with 8 output elements, each elements is the sum of Split_K elements +// Read Global: Each Warp/ThreadBlock: 32 threads_per_warp * 8 float_per_thread (256bit) -> 256 float per warp +// Write Global: Each Warp/ThreadBlock: 32 threads_per_warp * 8 half_per_thread (128bit) -> 256 half per warp +// GridSize = (M_Global*N_Global) / 256 + +#include +#include +#include + +#define REDUCTION_ELEMENT_PER_THREADBLOCK 256 +#define HALF_PER_128BIT 8 + +__global__ void SplitK_Reduction(half* C, float* Reduction_Workspace, size_t M_Global, size_t N_Global, int Split_K) +{ + half* WARP_GPTR_C = C + REDUCTION_ELEMENT_PER_THREADBLOCK * blockIdx.x; + float* WARP_GPTR_R = Reduction_Workspace + REDUCTION_ELEMENT_PER_THREADBLOCK * blockIdx.x; + half* THREAD_GPTR_C = WARP_GPTR_C + threadIdx.x * HALF_PER_128BIT; + float* THREAD_GPTR_R = WARP_GPTR_R + threadIdx.x * HALF_PER_128BIT; + // Initializing Thread-Local Results + float Results[HALF_PER_128BIT]; + #pragma unroll + for (int i = 0; i < HALF_PER_128BIT; i++) Results[i] = 0.0f; + // Reduction + for (int i = 0; i < Split_K; i++) { + #pragma unroll + for (int j = 0; j < HALF_PER_128BIT; j++) Results[j] += THREAD_GPTR_R[j]; + THREAD_GPTR_R += M_Global * N_Global; + } + // Writing to global memory + #pragma unroll + for (int i = 0; i < HALF_PER_128BIT; i++) THREAD_GPTR_C[i] = __float2half_rn(Results[i]); +} diff --git a/torchao/csrc/cuda/fp6_llm/ptx_cp.async.cuh b/torchao/csrc/cuda/fp6_llm/ptx_cp.async.cuh new file mode 100644 index 0000000000..c1d064f32a --- /dev/null +++ b/torchao/csrc/cuda/fp6_llm/ptx_cp.async.cuh @@ -0,0 +1,75 @@ +// Copyright 2024 FP6-LLM authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// This file is copied from https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/include/ptx_cp.async.cuh + +/*************************************************************************** + * Copyright 2023 The FLash-LLM Authors. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ ***************************************************************************/ +// Extended from CUTLASS's source code + +#ifndef PTX_CP_ASYNC_CUH +#define PTX_CP_ASYNC_CUH + +#include +#include +#include + +template +__device__ __forceinline__ void cp_async(half* smem_ptr, const half* global_ptr, bool pred_guard = true) +{ + static_assert(SizeInBytes == 16, "Size is not supported"); + unsigned smem_int_ptr = __cvta_generic_to_shared(smem_ptr); + asm volatile("{ \n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.cg.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)pred_guard), + "r"(smem_int_ptr), + "l"(global_ptr), + "n"(SizeInBytes)); +} + +/// Establishes an ordering w.r.t previously issued cp.async instructions. Does not block. +__device__ __forceinline__ void cp_async_group_commit() +{ + asm volatile("cp.async.commit_group;\n" ::); +} + +/// Blocks until all but previous cp.async.commit_group operations have committed. +template +__device__ __forceinline__ void cp_async_wait_group() +{ + asm volatile("cp.async.wait_group %0;\n" ::"n"(N)); +} + +/// Blocks until all previous cp.async.commit_group operations have committed. +// cp.async.wait_all is equivalent to : +// cp.async.commit_group; +// cp.async.wait_group 0; +__device__ __forceinline__ void cp_async_wait_all() +{ + asm volatile("cp.async.wait_all;\n" ::); +} + +#endif diff --git a/torchao/csrc/cuda/fp6_llm/ptx_mma.cuh b/torchao/csrc/cuda/fp6_llm/ptx_mma.cuh new file mode 100644 index 0000000000..d0985bd63d --- /dev/null +++ b/torchao/csrc/cuda/fp6_llm/ptx_mma.cuh @@ -0,0 +1,129 @@ +// Copyright 2024 FP6-LLM authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// This file is copied from https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/include/ptx_mma.cuh + +/*************************************************************************** + * Copyright 2023 The FLash-LLM Authors. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ ***************************************************************************/ +#ifndef PTX_MMA_CUH +#define PTX_MMA_CUH + +#include +#include +#include + +#include +#include "configs.h" + +#ifdef PIPELINE_LEVEL_SMEM +template +__device__ __forceinline__ void B_FromSharedToReg(uint32_t __restrict__ Reg[][4], + half __restrict__ (*read_SPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8], + int slice_id) { + #ifdef DEBUG_MODE + static_assert( (TilingConfig::WARP_COL_MMA_TENSORS==1) || (TilingConfig::WARP_COL_MMA_TENSORS%2==0) ); + #endif + + const int warpId = threadIdx.x / WARP_SIZE; + int lane_id = threadIdx.x % WARP_SIZE; + int WARP_j = warpId % TilingConfig::BLOCK_COL_WARPS; + int warp_start_col = TilingConfig::WARP_COL_MMA_TENSORS * MMA_8 * WARP_j; // each warp may start from reading warp_start_col'th column of the B tile in shared memory + #ifdef DEBUG_MODE + assert( warp_start_col==0 ); + #endif + + int col = (lane_id%8) + (lane_id/16)*8; + int row = (lane_id%16) / 8 * 8; + uint32_t smem_local_ptr = static_cast(__cvta_generic_to_shared(&read_SPTR[warp_start_col+col][slice_id*MMA_16 + row])); + if(TilingConfig::WARP_COL_MMA_TENSORS==1) { + asm volatile("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n" + : "=r"(Reg[0][0]), "=r"(Reg[0][1]) + : "r"(smem_local_ptr)); + } + else { + #pragma unroll + for (int i = 0; i < TilingConfig::WARP_COL_MMA_TENSORS/2; i++) + { + asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n" + : "=r"(Reg[i][0]), "=r"(Reg[i][1]), "=r"(Reg[i][2]), "=r"(Reg[i][3]) + : "r"(smem_local_ptr)); + smem_local_ptr += 16 * (WARP_K+PADDING_SHARED_MEM_FOR_B_8) * sizeof(half); + } + } +} +#else +// Debug: Whether ldmatrix.trans is required??? +// B is in column-major +template +__device__ __forceinline__ void B_FromSharedToReg(uint32_t __restrict__ Reg[][4], + half __restrict__ (*read_SPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8], + int k_offset) { + #ifdef DEBUG_MODE + static_assert( (TilingConfig::WARP_COL_MMA_TENSORS==1) || (TilingConfig::WARP_COL_MMA_TENSORS%2==0) ); + #endif + + const int warpId = threadIdx.x / WARP_SIZE; + int lane_id = threadIdx.x % WARP_SIZE; + int WARP_j = warpId % TilingConfig::BLOCK_COL_WARPS; + int warp_start_col = TilingConfig::WARP_COL_MMA_TENSORS * MMA_8 * WARP_j; // each warp may start from reading warp_start_col'th column of the B tile in shared memory + #ifdef DEBUG_MODE + assert( warp_start_col==0 ); + #endif + + int col = (lane_id%8) + (lane_id/16)*8; + int row = (lane_id%16) / 8 * 8; + uint32_t smem_local_ptr = static_cast(__cvta_generic_to_shared(&read_SPTR[warp_start_col+col][k_offset + row])); + if(TilingConfig::WARP_COL_MMA_TENSORS==1) { + asm volatile("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n" + : "=r"(Reg[0][0]), "=r"(Reg[0][1]) + : "r"(smem_local_ptr)); + } + else { + #pragma unroll + for (int i = 0; i < TilingConfig::WARP_COL_MMA_TENSORS/2; i++) + { + asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n" + : "=r"(Reg[i][0]), "=r"(Reg[i][1]), "=r"(Reg[i][2]), "=r"(Reg[i][3]) + : "r"(smem_local_ptr)); + smem_local_ptr += 16 * (WARP_K+PADDING_SHARED_MEM_FOR_B_8) * sizeof(half); + } + } +} +#endif + +__device__ __forceinline__ void +MMA_FP16_M16N8K16(uint32_t __restrict__ c[], uint32_t __restrict__ *a, uint32_t __restrict__ *b) +{ + asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" + "{ %0, %1, %2, %3}," + "{ %4, %5, %6, %7 }," + "{ %8, %9 }," + "{ %10, %11, %12, %13 };" + : "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3]) + : "r"(a[0]), "r"(a[1]), 
"r"(a[2]), "r"(a[3]), + "r"(b[0]), "r"(b[1]), + "r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3])); +} + +#endif diff --git a/torchao/csrc/cuda/fp6_llm/utils_core.cuh b/torchao/csrc/cuda/fp6_llm/utils_core.cuh new file mode 100644 index 0000000000..5bfc043ef6 --- /dev/null +++ b/torchao/csrc/cuda/fp6_llm/utils_core.cuh @@ -0,0 +1,216 @@ +// Copyright 2024 FP6-LLM authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// This file is copied from https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/include/utils_core.cuh + +#ifndef UTILS_CORE_CUH +#define UTILS_CORE_CUH + +#include + +#include "configs.h" +#include "ptx_mma.cuh" +#include "utils_parallel_dequant.cuh" + + +#ifdef PIPELINE_LEVEL_SMEM +template +__device__ __forceinline__ void CopyFromSharedToRegister_AFrag(uint32_t Reg[], uint32_t* SPTR, int slice_id) { + SPTR += slice_id * (NUM_INT_PER_THREAD*WARP_SIZE); + int lane_id = threadIdx.x % WARP_SIZE; + #pragma unroll + for(int i=0; i +__device__ __forceinline__ void initialize_mma_slice(uint32_t (*a)[4], + uint32_t (*b)[4], + uint32_t* __restrict__ A1_SPTR_read, + uint32_t* __restrict__ A2_SPTR_read, + half __restrict__ (*B_SPTR_read)[WARP_K+PADDING_SHARED_MEM_FOR_B_8], + uint32_t* RPTR_Scales) +{ + // Writing registers + // Registers to store FP6 fragments for a slice (64*16) of A matrix => 32 FP6 per thread => 6 register per thread; + uint32_t a_1[2]; // NO double buffer + uint32_t a_2[4]; // NO double buffer + CopyFromSharedToRegister_AFrag<2> (a_1, A1_SPTR_read, 0); + CopyFromSharedToRegister_AFrag<4> (a_2, A2_SPTR_read, 0); + Dequant_32FP6_4Way(a, a_1, a_2, RPTR_Scales); // SIMT Dequant: dequantizing FP6 to FP16 at register level, dequantizing a slice each time + B_FromSharedToReg(b, B_SPTR_read, 0); // Loading B from shared to registers +} + +template +__device__ __forceinline__ void core_mma_slice(float c[][REG_PER_THREAD_C_TENSOR_16_16], + uint32_t (*a)[4], + uint32_t (*b)[4], + uint32_t* __restrict__ A1_SPTR_read, + uint32_t* __restrict__ A2_SPTR_read, + half __restrict__ (*B_SPTR_read)[WARP_K+PADDING_SHARED_MEM_FOR_B_8], + uint32_t* RPTR_Scales, + int slice_id) // writing slice[slice_id] to registers, k=0 -> slice_id=1 for prefetching +{ + #ifdef DEBUG_MODE + assert((TilingConfig::WARP_COL_MMA_TENSORS==1) || (TilingConfig::WARP_COL_MMA_TENSORS%2==0)); // if WARP_COL_MMA_TENSORS == 1, B tile in registers is padded to a 16*16 MMA block + #endif + const int NumRegSets_a = WARP_ROW_MMA_TENSORS; // 1 set = 4 registers, containing a 16*16 MMA block + const int NumRegSets_b = (TilingConfig::WARP_COL_MMA_TENSORS==1) ? 
1 : TilingConfig::WARP_COL_MMA_TENSORS/2; // 1 set = 4 registers, containing a 16*16 MMA block + uint32_t (*c_uint_ptr)[REG_PER_THREAD_C_TENSOR_16_16] = reinterpret_cast(c); // Reigsters for accumulated FP32 results + + // Setting RPTRs for double buffers + uint32_t (*a_read )[4] = a; + uint32_t (*a_write)[4] = a; + uint32_t (*b_read )[4] = b; + uint32_t (*b_write)[4] = b; + if(slice_id%2==1) { b_write += NumRegSets_b; a_write += NumRegSets_a;} + else { b_read += NumRegSets_b; a_read += NumRegSets_a;} + + // Reading registers and issuing core tensor core computations (a slice of A and B tile in shared memory) + #pragma unroll + for (int i = 0; i < WARP_ROW_MMA_TENSORS; i++) { + if(TilingConfig::WARP_COL_MMA_TENSORS==1) { + MMA_FP16_M16N8K16( c_uint_ptr[i], a_read[i], b_read[0] ); + } + else { + #pragma unroll + for (int j = 0; j < TilingConfig::WARP_COL_MMA_TENSORS/2; j++) { + MMA_FP16_M16N8K16( c_uint_ptr[i + j * WARP_ROW_MMA_TENSORS], a_read[i], b_read[j] ); + MMA_FP16_M16N8K16( c_uint_ptr[i + j * WARP_ROW_MMA_TENSORS] + 4, a_read[i], b_read[j] + 2 ); // c+4; b+2 + } + } + } + + // Writing registers + // Registers to store FP6 fragments for a slice (64*16) of A matrix => 32 FP6 per thread => 6 register per thread; + uint32_t a_1[2]; // NO double buffer + uint32_t a_2[4]; // NO double buffer + CopyFromSharedToRegister_AFrag<2> (a_1, A1_SPTR_read, slice_id); + CopyFromSharedToRegister_AFrag<4> (a_2, A2_SPTR_read, slice_id); + Dequant_32FP6_4Way(a_write, a_1, a_2, RPTR_Scales); // SIMT Dequant: dequantizing FP6 to FP16 at register level, dequantizing a slice each time + B_FromSharedToReg (b_write, B_SPTR_read, slice_id); // Loading B from shared to registers +} + +#else +// Old version with naive pipeline design +template +__device__ __forceinline__ void CopyFromSharedToRegister_AFrag(uint32_t Reg[], uint32_t* SPTR) { + int lane_id = threadIdx.x % WARP_SIZE; + #pragma unroll + for(int i=0; i +__device__ __forceinline__ void PipelinedCoreLoop(float c[][REG_PER_THREAD_C_TENSOR_16_16], + half __restrict__ (*read_SPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8], + uint32_t* __restrict__ read_SPTR_Frag1, + uint32_t* __restrict__ read_SPTR_Frag2, + uint32_t* RPTR_Scales) +{ + #ifdef DEBUG_MODE + assert((TilingConfig::WARP_COL_MMA_TENSORS==1) || (TilingConfig::WARP_COL_MMA_TENSORS%2==0)); // if WARP_COL_MMA_TENSORS == 1, B tile in registers is padded to a 16*16 MMA block + #endif + const int NumRegSets_a = WARP_ROW_MMA_TENSORS; // 1 set = 4 registers, containing a 16*16 MMA block + const int NumRegSets_b = (TilingConfig::WARP_COL_MMA_TENSORS==1) ? 
1 : TilingConfig::WARP_COL_MMA_TENSORS/2; // 1 set = 4 registers, containing a 16*16 MMA block + + // Reigsters to store FP32 results + uint32_t (*c_uint_ptr)[REG_PER_THREAD_C_TENSOR_16_16] = reinterpret_cast(c); + // Registers to store FP6 fragments for a slice (64*16) of A matrix => 32 FP6 per thread => 6 register per thread; + uint32_t a_1[2*2]; // double buffer is used + uint32_t a_2[4*2]; // double buffer is used + // Registers to store decompressed FP6 + uint32_t a [NumRegSets_a * 1][4]; // No double buffer + // Register to store FP16 B matrix (a slice) + uint32_t b [NumRegSets_b * 2][4]; // double buffer is used + + // Overlapped Smem and TC pipeline: pre-loading from shared to registers + CopyFromSharedToRegister_AFrag<2> (a_1, read_SPTR_Frag1); + CopyFromSharedToRegister_AFrag<4> (a_2, read_SPTR_Frag2); + B_FromSharedToReg (b, read_SPTR, 0); + + #pragma unroll + for (int k = 0; k < WARP_K_MMA_TENSORS; k++) { + uint32_t (*b_read)[4] = b; + uint32_t (*b_write)[4] = b; + uint32_t *a_1_read = a_1; + uint32_t *a_1_write = a_1; + uint32_t *a_2_read = a_2; + uint32_t *a_2_write = a_2; + if(k%2==0) { + b_write += NumRegSets_b; + a_1_write += 2; + a_2_write += 4; + } + else { + b_read += NumRegSets_b; + a_1_read += 2; + a_2_read += 4; + } + // data loading + if (k + 1 < WARP_K_MMA_TENSORS) { + // updating SPTR for fragment1 and fragment2 + read_SPTR_Frag1 += 2*WARP_SIZE; + read_SPTR_Frag2 += 4*WARP_SIZE; + CopyFromSharedToRegister_AFrag<2>(a_1_write, read_SPTR_Frag1); + CopyFromSharedToRegister_AFrag<4>(a_2_write, read_SPTR_Frag2); + B_FromSharedToReg(b_write, read_SPTR, (k+1)*MMA_16); + } + // SIMT Dequant + Tensor Core computations + Dequant_32FP6_4Way(a, a_1_read, a_2_read, RPTR_Scales); // Dequantizing FP6 to FP16 at register level, dequantizing a slice each time + #pragma unroll + for (int i = 0; i < WARP_ROW_MMA_TENSORS; i++) { + if(TilingConfig::WARP_COL_MMA_TENSORS==1) + MMA_FP16_M16N8K16( c_uint_ptr[i], a[i], b_read[0] ); + else { + #pragma unroll + for (int j = 0; j < TilingConfig::WARP_COL_MMA_TENSORS/2; j++) { + MMA_FP16_M16N8K16( c_uint_ptr[i + j * WARP_ROW_MMA_TENSORS], a[i], b_read[j] ); + MMA_FP16_M16N8K16( c_uint_ptr[i + j * WARP_ROW_MMA_TENSORS] + 4, a[i], b_read[j] + 2 ); // c+4; b+2 + } + } + } + } +} +#endif // #ifdef PIPELINE_LEVEL_SMEM + +template +__device__ __forceinline__ void StoreToSharedMemoryFromRegister(float (*smem_CFrag)[TilingConfig::TILE_M + PADDING_SHARED_MEM_FOR_C_4], + float c[][REG_PER_THREAD_C_TENSOR_16_16]) +{ + const int lane_id = threadIdx.x % WARP_SIZE; + const int warpId = threadIdx.x / WARP_SIZE; + int warp_row_offset = warpId * (MMA_16 * WARP_ROW_MMA_TENSORS); + #pragma unroll + for (int i = 0; i < WARP_ROW_MMA_TENSORS; i++) { + #pragma unroll + for (int j = 0; j < TilingConfig::WARP_COL_MMA_TENSORS; j++) { // Dealing with one 16*8 Tensor + int RegSetID = i + (j/2)*WARP_ROW_MMA_TENSORS; + int RegOffset = (j%2)*(REG_PER_THREAD_C_TENSOR_16_16/2); + int Tensor_row_offset = warp_row_offset + i * MMA_16; + int Tensor_col_offset = j * MMA_8; + #pragma unroll + for (int r = 0; r < REG_PER_THREAD_C_TENSOR_16_16/2; r++) { + int row_offset = lane_id / 4; + if (r >= 2) row_offset += 8; + int col_offset = (lane_id % 4) * 2; + if (r%2==1) col_offset += 1; + smem_CFrag[Tensor_col_offset + col_offset][Tensor_row_offset + row_offset] = c[RegSetID][r + RegOffset]; + } + } + } +} + +#endif diff --git a/torchao/csrc/cuda/fp6_llm/utils_gmem.cuh b/torchao/csrc/cuda/fp6_llm/utils_gmem.cuh new file mode 100644 index 0000000000..5c37452e13 --- /dev/null +++ 
b/torchao/csrc/cuda/fp6_llm/utils_gmem.cuh @@ -0,0 +1,91 @@ +// Copyright 2024 FP6-LLM authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// This file is copied from https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/include/utils_gmem.cuh + +#ifndef UTILS_GMEM_CUH +#define UTILS_GMEM_CUH + +#include +#include "configs.h" +#include "ptx_cp.async.cuh" + +/* + * Copying A1/A2 from global memory to shared memory. + * Usually 1024 or 2048 Bytes + */ +template +__device__ __forceinline__ void CopyFromGlobalToShared_A(uint32_t* SPTR, + const uint4* GPTR, + bool pred_guard = true) { + #ifdef DEBUG_MODE + static_assert(SMEM_SIZE_IN_BYTES_PER_WARP/WARP_SIZE % 16 == 0); + #endif + int lane_id = threadIdx.x % WARP_SIZE; + half* SPTR_HALF = reinterpret_cast(SPTR); + const half* GPTR_HALF = reinterpret_cast(GPTR); + SPTR_HALF += lane_id*8; + GPTR_HALF += lane_id*8; + #pragma unroll + for(int i=0; i( SPTR_HALF, GPTR_HALF, pred_guard); + SPTR_HALF += 256; // Forward 512 Bytes + GPTR_HALF += 256; // Forward 512 Bytes + } + +} + +/* + * Copying 64 Quant Scales (FP16) from global memory to shared memory. + */ +__device__ __forceinline__ void CopyFromGlobalToShared_Scales(half* SPTR_QuantScales, + const half* GPTR_A_Scales) { + int lane_id = threadIdx.x % WARP_SIZE; + int Offset_Shared = lane_id*2; + int Offset_Global = lane_id/4 + (lane_id%4)*16; + for(int i=0; i<2; i++) SPTR_QuantScales[Offset_Shared+i] = GPTR_A_Scales[Offset_Global+i*8]; +} + +/* + * (1) Copying X rows * 64 columns of FP16 values, originally in row major + * (2) Copying 64 rows * X columns of FP16 values, originally in column major + * 16 Bytes per thread -> 512 Bytes per WARP = 4 line per WARP = 1 line per 8 Threads + */ +template +__device__ __forceinline__ void CopyFromGlobalToShared(half __restrict__ (*SharedPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8], + const half* GlobalPTR, + const int GlobalStride, + const int NumOfLinesLeft, // To support arbitrary N dimensions. 
+ bool Pred = true) { + // static parameters: 1 Group (8 Threads) can copy 1 line (64 FP16) each time + const int NumOfThreads = BLOCK_WARPS * WARP_SIZE; + const int NumOfGroups = NumOfThreads / 8; + const int MaxIteration = (MaxNumOfLinesToCopy-1) / NumOfGroups + 1; + // runtime variables + const int line_id = threadIdx.x / 8; + const int line_offset = (threadIdx.x%8) * 8; + // PTR for source global memory and target shared memory + GlobalPTR += line_id * GlobalStride + line_offset; + SharedPTR += line_id; + #pragma unroll + for (int i = 0; i < MaxIteration; i++) { + bool AsyncCopyPred = (line_id+i*NumOfGroups) < NumOfLinesLeft && Pred; + cp_async<16>( &(*SharedPTR)[line_offset], GlobalPTR, AsyncCopyPred); + // + GlobalPTR += NumOfGroups * GlobalStride; + SharedPTR += NumOfGroups; + } +} + +#endif diff --git a/torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh b/torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh new file mode 100644 index 0000000000..f6ce4cc046 --- /dev/null +++ b/torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh @@ -0,0 +1,127 @@ +// Copyright 2024 FP6-LLM authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// This file is copied from https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/include/utils_parallel_dequant.cuh + +#ifndef UTILS_PARALLELDEQUANT_CUH +#define UTILS_PARALLELDEQUANT_CUH + +#include +#include +#include + +/* + * Input: R1 + * Outputs: R1, R2 + * Note: Simplified Exponent calculation is applied. + */ +__device__ __forceinline__ void FP6_FP16_Cast_4Way(u_int32_t *R1, u_int32_t *R2) { + *R2 = *R1 & 0x80808080; + *R1 = *R1 >> 2; + *R1 = *R1 & 0x1f1f1f1f; + *R2 = *R2 | *R1; + *R1 = *R2 & 0x9f009f00; + *R2 = *R2 & 0x009f009f; + *R2 = *R2 << 8; +} + +/* + * Input: R1 + * Outputs: R1, R2 + * Note: Simplified Exponent calculation is NOT applied. 
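+ *
+ * Note (inferred from the bit masks here, not from the upstream comments): both cast variants
+ * expect each byte of R1 to hold one FP6 value left-aligned as [sign | 3-bit exp | 2-bit mantissa | 00],
+ * and each produces two packed FP16 values per output register. The simplified variant above
+ * zero-extends the 3-bit exponent into FP16's 5-bit field, leaving results scaled by 2^-12,
+ * which MultScale() below compensates for with its 4096.0f factor.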
+ */ +__device__ __forceinline__ void FP6_FP16_Cast_4Way_Naive(u_int32_t *R1, u_int32_t *R2) { + //*R2 = *R1 & 0x80808080; + *R2 = *R1 & 0xc0c0c0c0; + *R1 = *R1 >> 2; + //*R1 = *R1 & 0x1f1f1f1f; + *R1 = *R1 & 0x0f0f0f0f; + *R2 = *R2 | *R1; + // + //*R1 = *R2 & 0x9f009f00; + //*R2 = *R2 & 0x009f009f; + *R1 = *R2 & 0xcf00cf00; + if( !(*R1 & 0x40000000) && (*R1 & 0x0c000000) ) *R1 = *R1 | 0x30000000; + if( !(*R1 & 0x00004000) && (*R1 & 0x00000c00) ) *R1 = *R1 | 0x00003000; + *R2 = *R2 & 0x00cf00cf; + if( !(*R2 & 0x00400000) && (*R2 & 0x000c0000) ) *R2 = *R2 | 0x00300000; + if( !(*R2 & 0x00000040) && (*R2 & 0x0000000c) ) *R2 = *R2 | 0x00000030; + // + *R2 = *R2 << 8; + //*R1 = 0x3c003c00; + //*R2 = 0x3c003c00; +} + +__device__ __forceinline__ u_int32_t MultScale(u_int32_t PackedFP16Pair, half Scale) { + half* FP16_1 = reinterpret_cast(&PackedFP16Pair); + half* FP16_2 = FP16_1 + 1; + uint32_t output; + half* output_half_ptr = reinterpret_cast(&output); + output_half_ptr[0] = __hmul( __hmul(*FP16_1,__float2half(4096.0f)), Scale); + output_half_ptr[1] = __hmul( __hmul(*FP16_2,__float2half(4096.0f)), Scale); + return output; +} + +__device__ __forceinline__ void Dequant_32FP6_4Way(u_int32_t __restrict__ Reg[][4], + u_int32_t __restrict__ *read_RPTR_Frag1, + u_int32_t __restrict__ *read_RPTR_Frag2, + u_int32_t *Scales) { + u_int32_t *OutputRegs = reinterpret_cast (Reg); + u_int32_t *Frag1_PTR = read_RPTR_Frag1; + u_int32_t *Frag2_PTR = read_RPTR_Frag2; + half *Scale_RPTR = reinterpret_cast(Scales); + u_int32_t Packed_FP6 = 0; + u_int32_t tmp = 0; + // Dequantizing 32 FP6, each Loop dequantizing 4 FP6 + #pragma unroll(8) + for(int i=0; i<8; i++) { + // Frag1 + Packed_FP6 = (*Frag1_PTR) & 0xc0c0c0c0; + if(i%4==3) Frag1_PTR++; + else (*Frag1_PTR) = (*Frag1_PTR) << 2; + // Frag2 + tmp = (*Frag2_PTR) & 0xf0f0f0f0; + tmp = tmp >> 2; + if(i%2==1) Frag2_PTR++; + else (*Frag2_PTR) = (*Frag2_PTR) << 4; + // Packed_FP6 + Packed_FP6 = Packed_FP6 | tmp; + // + FP6_FP16_Cast_4Way(&Packed_FP6, &tmp); + // + *OutputRegs = MultScale(Packed_FP6, Scale_RPTR[0] ); // Muliply FP16 scales + OutputRegs += 1; + *OutputRegs = MultScale(tmp, Scale_RPTR[1]); // Muliply FP16 scales + OutputRegs += 1; + // Updating offset for FP16 scales for every two iterations + if(i%2==1) Scale_RPTR += 2; + } + +} + +/* + * + */ +__device__ __forceinline__ void ExtractFromSharedToReg_Scales(uint32_t* Scales, half* WARP_SPTR_Scales) { + int lane_id = threadIdx.x % WARP_SIZE; + uint32_t* SPTR_uint = reinterpret_cast(WARP_SPTR_Scales); + uint32_t tmpReg = SPTR_uint[lane_id]; + #pragma unroll + for(int i=0; i<4; i++) { + // T __shfl_sync(unsigned mask, T var, int srcLane, int width=warpSize); + Scales[i] = __shfl_sync(0xffffffff, tmpReg, i, 4); + } +} + +#endif diff --git a/torchao/csrc/cuda/fp6_llm/weight_quant.cu b/torchao/csrc/cuda/fp6_llm/weight_quant.cu new file mode 100644 index 0000000000..d29f70be0c --- /dev/null +++ b/torchao/csrc/cuda/fp6_llm/weight_quant.cu @@ -0,0 +1,219 @@ +// Copyright 2024 FP6-LLM authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. +// +// This file is adapted from https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/utils/weight_quant.h +// and https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/utils/weight_dequant.h + +#include +#include +#include + +/* + * Function to pack 4 fake quantized FP16 value into continuously stored 4 FP6 values. + */ +void cast_fp16_fp6(uint16_t* FP16x4, uint8_t* FP6x4) +{ + // Constants for FP6 + constexpr int exponent_nbits_fp6 = 3; + constexpr int mantissa_nbits_fp6 = 2; + constexpr int exp_bias_fp6 = (1 << (exponent_nbits_fp6 - 1)) - 1; + // Constants for FP16 + constexpr int exponent_nbits_fp16 = 5; + constexpr int mantissa_nbits_fp16 = 10; + constexpr int exp_bias_fp16 = (1 << (exponent_nbits_fp16 - 1)) - 1; + + int fp6_temp[4]; + + float absmin_nonzero_fp6 = 0.0625; + // Note that we regard the exponent of '111' as a regular value rather than NaN or inf. This is + // the same with that in qtorch. + float absmax_fp6 = 28; + + for (int i = 0; i < 4; ++i) { + uint16_t source = FP16x4[i]; + float fp6_value_abs = std::abs(__half2float(*((half*)(&source)))); + if ((fp6_value_abs != 0 && fp6_value_abs < absmin_nonzero_fp6) || + fp6_value_abs > absmax_fp6) { + // TODO(zhen): a better way may be rounding it to the nearest FP6 value. + throw std::invalid_argument("Input value out of range for FP6."); + } + + // It is not safe to do shift operation on uint16_t. So we promote it to int. + int source_promote = int(source); + + int sign_bit = (source_promote >> 15); + // Extracting exponent represented in FP16. The sign mask 0x7FFF is '0111 1111 1111 1111' + int exp_bit = (source_promote & 0x7FFF) >> mantissa_nbits_fp16; + // Extracting mantissa represented in FP16 + int mant_bit = source_promote & ((1 << mantissa_nbits_fp16) - 1); + + int new_exp_bit; + int new_mant_bit; + + if (exp_bit == 0) { + // Subnormal FP16 number. Too small for FP6. + new_exp_bit = 0; + new_mant_bit = 0; + } else { + new_mant_bit = mant_bit >> (mantissa_nbits_fp16 - mantissa_nbits_fp6); + new_exp_bit = exp_bit - exp_bias_fp16 + exp_bias_fp6; + + // Deal with subnormal FP6 values. + int target_exp_val = exp_bit - exp_bias_fp16; + int min_fp6_exp_val = -exp_bias_fp6 + 1; + bool subnormal_fp6 = target_exp_val < min_fp6_exp_val; + if (subnormal_fp6) { + // TODO(zhen): add the rounding logic. + new_exp_bit = 0; + // The implicit 1 in the mantissa of FP16 is not present in subnormal FP6. Thus we + // need to add it + new_mant_bit = (new_mant_bit | (1 << mantissa_nbits_fp6)) >> + (min_fp6_exp_val - target_exp_val); + } + } + + fp6_temp[i] = (sign_bit << (exponent_nbits_fp6 + mantissa_nbits_fp6)) | + (new_exp_bit << mantissa_nbits_fp6) | new_mant_bit; + } + // Pack the values + FP6x4[0] = fp6_temp[0] << 2 | (fp6_temp[1] >> 4); + FP6x4[1] = (fp6_temp[1] & 0x0F) << 4 | (fp6_temp[2] >> 2); + FP6x4[2] = (fp6_temp[2] & 0x03) << 6 | fp6_temp[3]; +} + +/* + * Function to prepack FP16 weights into continuous FP6 values. + * + * Parameters: + * weight_16bit: input weight in FP16, size M*K + * weight_6bit: output weight in packed FP6, continuously stored, size M*K*6/8 + * M, K: the shape of the weight + */ +void weight_prepacking_fp16_to_fp6(uint16_t* weight_16bit, + uint8_t* weight_6bit_packed, + size_t M, + size_t K) +{ + // Every four 16-bit elements are packed into three 6-bit values (4*6bit == 3*8bit). 
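+    // Byte layout produced by cast_fp16_fp6() above (a sketch: four 6-bit codes v0..v3 land in three bytes):
+    //   byte0 = v0[5:0] << 2 | v1[5:4]
+    //   byte1 = v1[3:0] << 4 | v2[5:2]
+    //   byte2 = v2[1:0] << 6 | v3[5:0]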
+ if (K * 6 % 8 != 0) { throw std::invalid_argument("(K * 6 % 8) should be 0"); } + size_t K_fp6_packed = K * 6 / 8; + // #pragma omp parallel for + for (auto m = 0; m < M; m++) { + uint8_t* ptr_6bit = weight_6bit_packed + m * K_fp6_packed; + uint16_t* ptr_16bit = weight_16bit + m * K; + for (auto k = 0; k < K; k += 4) { + cast_fp16_fp6(ptr_16bit, ptr_6bit); + ptr_16bit += 4; + ptr_6bit += 3; + } + } +} + +void DeQuantMatrix_FP6_To_FP16(half* A_16bit_h, unsigned char* A_6bit_h, size_t M, size_t K, half* scale) { + assert(M%64==0); // Currently, M must be a multiple of 64. + assert(K%64==0); // Currently, K must be a multiple of 64. + size_t TotalSizeInByte = M*K*6/8; + // + half* OutPTR = A_16bit_h; + for(size_t i=0; i>2)&0x1f); + unsigned char B2 = (A_6bit_h[i*3+0]<<6) | ((A_6bit_h[i*3+1]>>2)&0xfc); + B2 = (B2&0x80) | ((B2>>2)&0x1f); + unsigned char B3 = (A_6bit_h[i*3+1]<<4) | ((A_6bit_h[i*3+2]>>4)&0xfc); + B3 = (B3&0x80) | ((B3>>2)&0x1f); + unsigned char B4 = A_6bit_h[i*3+2]<<2; + B4 = (B4&0x80) | ((B4>>2)&0x1f); + half FP1, FP2, FP3, FP4; + unsigned char *PTR1, *PTR2, *PTR3, *PTR4; + PTR1 = reinterpret_cast(&FP1); + PTR2 = reinterpret_cast(&FP2); + PTR3 = reinterpret_cast(&FP3); + PTR4 = reinterpret_cast(&FP4); + PTR1[0] = 0; PTR1[1] = B1; // small endian for X86 CPU + PTR2[0] = 0; PTR2[1] = B2; + PTR3[0] = 0; PTR3[1] = B3; + PTR4[0] = 0; PTR4[1] = B4; + OutPTR[0] = __float2half_rn ( __half2float(FP1) * 4096.0f * __half2float(scale[(4*i)/K]) ); + OutPTR[1] = __float2half_rn ( __half2float(FP2) * 4096.0f * __half2float(scale[(4*i)/K]) ); + OutPTR[2] = __float2half_rn ( __half2float(FP3) * 4096.0f * __half2float(scale[(4*i)/K]) ); + OutPTR[3] = __float2half_rn ( __half2float(FP4) * 4096.0f * __half2float(scale[(4*i)/K]) ); + // + OutPTR +=4; + } +} + + +#include +#include +#include + +namespace torchao { + +// https://github.com/microsoft/DeepSpeed/blob/0fc19b6a320cf8aa0a5f6c2b1fa310bae9a70d94/deepspeed/inference/v2/kernels/core_ops/cuda_linear/linear_kernels.cpp#L194 +at::Tensor fp16_to_fp6_cpu(at::Tensor fp16_tensor) +{ + TORCH_CHECK(fp16_tensor.dim() == 2, "weight must be 2-dimensional"); + TORCH_CHECK(fp16_tensor.scalar_type() == torch::kFloat16, "weight must be FP16"); + TORCH_CHECK(fp16_tensor.is_contiguous(), "weight must be contiguous"); + TORCH_CHECK(fp16_tensor.device().type() == torch::kCPU, "weight must be on CPU"); + auto M = fp16_tensor.size(0); + auto K = fp16_tensor.size(1); + TORCH_CHECK(K % 4 == 0, "K must be multiple of 4"); + + // Pack weight from FP16 to FP6. + auto options = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU); + auto packed_fp6_tensor = at::empty({M, K * 6 / 8}, options); + uint8_t* packed_fp6_ptr = packed_fp6_tensor.data_ptr(); + + uint16_t* fake_fp6_ptr = reinterpret_cast(fp16_tensor.data_ptr()); + weight_prepacking_fp16_to_fp6(fake_fp6_ptr, packed_fp6_ptr, M, K); + + return packed_fp6_tensor; +} + +/* + * Dequant a FP6 matrix to a equivalent FP16 matrix using CPUs. + * A useful tool to construct input matrices for the FP16 GEMM baseline. + * [Input] + * fp6_tensor: int tensor of shape [OC, IC // 16 * 3]; // 3 INT32 words contains 16 FP6 weights. + * fp16_scale: half tensor of shape [OC]; // for row-wise quantization. + * [Output] + * fp16_tensor: half tensor of shape [OC, IC]. 
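+ * [Note]
+ * IC is recovered from the packed width as fp6_tensor.size(1) / 3 * 16. The op is exposed to
+ * Python as torchao.ops.fp6_weight_dequant via the TORCH_LIBRARY_IMPL registration below.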
+ */ +at::Tensor weight_matrix_dequant_cpu(at::Tensor fp6_tensor, at::Tensor fp16_scale) +{ + int OC = fp6_tensor.size(0); + TORCH_CHECK(fp6_tensor.size(1) % 3 == 0); + int IC = fp6_tensor.size(1) / 3 * 16; + TORCH_CHECK(fp16_scale.size(0) == OC); + // + auto fp6_tensor_ptr = reinterpret_cast(fp6_tensor.data_ptr()); + auto fp16_scale_ptr = reinterpret_cast(fp16_scale.data_ptr()); + // + auto options = at::TensorOptions().dtype(at::kHalf).device(fp16_scale.device()); + at::Tensor fp16_tensor = at::empty({OC, IC}, options); + auto fp16_tensor_ptr = reinterpret_cast(fp16_tensor.data_ptr()); + // + DeQuantMatrix_FP6_To_FP16(fp16_tensor_ptr, fp6_tensor_ptr, OC, IC, fp16_scale_ptr); + // + return fp16_tensor; +} + +TORCH_LIBRARY_IMPL(torchao, CPU, m) { + m.impl("torchao::fp16_to_fp6", &fp16_to_fp6_cpu); + m.impl("torchao::fp6_weight_dequant", &weight_matrix_dequant_cpu); +} + +} diff --git a/torchao/csrc/fp6_llm/README.md b/torchao/csrc/fp6_llm/README.md new file mode 100644 index 0000000000..ff764cc27d --- /dev/null +++ b/torchao/csrc/fp6_llm/README.md @@ -0,0 +1,7 @@ +# FP6-LLM kernel + +This kernel is adapted from https://github.com/usyd-fsalab/fp6_llm. It performs linear op (A @ W.T), where A is in FP16 and W is in FP6 (E3M2 without infinities and NaN). + +On most hardware, this kernel is faster than FP16 linear for batch size from 1 to 128, and slower for batch size larger than or equal to 256. See https://github.com/usyd-fsalab/fp6_llm/issues/8 for a detailed discussion. + +See https://github.com/pytorch/ao/pull/223 for some benchmark results. diff --git a/torchao/csrc/fp6_llm/fp6_llm.cpp b/torchao/csrc/fp6_llm/fp6_llm.cpp new file mode 100644 index 0000000000..794c79df11 --- /dev/null +++ b/torchao/csrc/fp6_llm/fp6_llm.cpp @@ -0,0 +1,11 @@ +#include +#include +#include + +TORCH_LIBRARY_FRAGMENT(torchao, m) { + m.impl_abstract_pystub("torchao.ops"); + m.def("fp16act_fp6weight_linear(Tensor _in_feats, Tensor _weights, Tensor _scales, int splitK) -> Tensor"); + m.def("prepack_fp6_weight(Tensor fp6_tensor) -> Tensor"); + m.def("fp16_to_fp6(Tensor fp16_tensor) -> Tensor"); + m.def("fp6_weight_dequant(Tensor fp6_tensor, Tensor fp16_scale) -> Tensor"); +} diff --git a/torchao/csrc/fp6_llm/weight_prepacking.cpp b/torchao/csrc/fp6_llm/weight_prepacking.cpp new file mode 100644 index 0000000000..89a1171f5e --- /dev/null +++ b/torchao/csrc/fp6_llm/weight_prepacking.cpp @@ -0,0 +1,220 @@ +// Copyright 2024 FP6-LLM authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +// This file is adapted from https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/utils/weight_prepacking.h + +#include +#include +#include + +using namespace std; + +void Padding_8_FP6_To_8_Bytes(unsigned char Padded_FP6[], unsigned char* FP6_Array) // padding 0 to the lowerest bit location +{ + Padded_FP6[0] = FP6_Array[0] & 0xfc; + Padded_FP6[1] = (FP6_Array[0]<<6) | ((FP6_Array[1]>>2) & 0xfc); + Padded_FP6[2] = (FP6_Array[1]<<4) | ((FP6_Array[2]>>4) & 0xfc ); + Padded_FP6[3] = FP6_Array[2]<<2; + Padded_FP6[4] = FP6_Array[3] & 0xfc; + Padded_FP6[5] = (FP6_Array[3]<<6) | ((FP6_Array[4]>>2) & 0xfc); + Padded_FP6[6] = (FP6_Array[4]<<4) | ((FP6_Array[5]>>4) & 0xfc); + Padded_FP6[7] = FP6_Array[5]<<2; +} + +unsigned char Extract_2_Bits_From_4_PaddedFP6(unsigned char B1, unsigned char B2, unsigned char B3, unsigned char B4) +{ + unsigned char out; + out = (B1&0xc0) | ( (B2&0xc0) >> 2 ) | ( (B3&0xc0) >> 4 ) | ( (B4&0xc0) >> 6 ); + return out; +} + +unsigned char Extract_4_Bits_From_2_PaddedFP6(unsigned char B1, unsigned char B2) // The highest two bits are already extracted by Extract_2_Bits_From_4_PaddedFP6(); +{ + unsigned char out; + out = ( (B1<<2) & 0xf0 ) | ( (B2>>2) & 0x0f ); + return out; +} + +// dealing with 4 1*8 blocks of FP6 +void Assign_32_FP6_To_4_Thread(vector Seg_2bit[], vector Seg_4bit[], unsigned char* PTR_1, unsigned char* PTR_2, unsigned char* PTR_3, unsigned char* PTR_4) +{ + unsigned char Padded_8_FP8[4][8]; + Padding_8_FP6_To_8_Bytes(Padded_8_FP8[0], PTR_1); + Padding_8_FP6_To_8_Bytes(Padded_8_FP8[1], PTR_2); + Padding_8_FP6_To_8_Bytes(Padded_8_FP8[2], PTR_3); + Padding_8_FP6_To_8_Bytes(Padded_8_FP8[3], PTR_4); + // + unsigned char Seg1_Byte1_T[4]; + unsigned char Seg1_Byte2_T[4]; + unsigned char Seg2_Byte1_T[4]; + unsigned char Seg2_Byte2_T[4]; + unsigned char Seg2_Byte3_T[4]; + unsigned char Seg2_Byte4_T[4]; + for(int t=0; t<4; t++) + { + Seg1_Byte1_T[t] = Extract_2_Bits_From_4_PaddedFP6(Padded_8_FP8[0][0+t*2], Padded_8_FP8[0][1+t*2], Padded_8_FP8[1][0+t*2], Padded_8_FP8[1][1+t*2]); + Seg1_Byte2_T[t] = Extract_2_Bits_From_4_PaddedFP6(Padded_8_FP8[2][0+t*2], Padded_8_FP8[2][1+t*2], Padded_8_FP8[3][0+t*2], Padded_8_FP8[3][1+t*2]); + Seg2_Byte1_T[t] = Extract_4_Bits_From_2_PaddedFP6(Padded_8_FP8[0][0+t*2], Padded_8_FP8[0][1+t*2]); + Seg2_Byte2_T[t] = Extract_4_Bits_From_2_PaddedFP6(Padded_8_FP8[1][0+t*2], Padded_8_FP8[1][1+t*2]); + Seg2_Byte3_T[t] = Extract_4_Bits_From_2_PaddedFP6(Padded_8_FP8[2][0+t*2], Padded_8_FP8[2][1+t*2]); + Seg2_Byte4_T[t] = Extract_4_Bits_From_2_PaddedFP6(Padded_8_FP8[3][0+t*2], Padded_8_FP8[3][1+t*2]); + } + // + for(int t=0; t<4; t++) + { + Seg_2bit[t].push_back(Seg1_Byte1_T[t]); + Seg_2bit[t].push_back(Seg1_Byte2_T[t]); + Seg_4bit[t].push_back(Seg2_Byte1_T[t]); + Seg_4bit[t].push_back(Seg2_Byte2_T[t]); + Seg_4bit[t].push_back(Seg2_Byte3_T[t]); + Seg_4bit[t].push_back(Seg2_Byte4_T[t]); + } + return; +} + +void BitInterleaving_2bit(unsigned char* PTR_4Bytes) +{ + unsigned int *PTR_UINT = reinterpret_cast(PTR_4Bytes); + unsigned int input = *PTR_UINT; + // + //int order_2bit[16] = {1,5,9,13,3,7,11,15,2,6,10,14,4,8,12,16}; // pre-defined order for bit-interleaving in FP6-LLM + int order_2bit[16] = {2,6,10,14,4,8,12,16,1,5,9,13,3,7,11,15}; // pre-defined order for bit-interleaving in FP6-LLM + unsigned int Frags_2bit[16]; // The highest 2 bits are used to store the extracted fragments. 
+ for(int i=0; i<16; i++) + Frags_2bit[i] = ( input << 2*(order_2bit[i]-1) ) & 0xc0000000; + // + unsigned int output = 0x00000000; + for(int i=0; i<16; i++) + output |= ( Frags_2bit[i] >> (i*2) ); + // + *PTR_UINT = output; +} + +void BitInterleaving_4bit(unsigned char* PTR_4Bytes) +{ + unsigned int *PTR_UINT = reinterpret_cast(PTR_4Bytes); + unsigned int input = *PTR_UINT; + // + //int order_4bit[8] = {1,5,3,7,2,6,4,8}; // pre-defined order for bit-interleaving in FP6-LLM + int order_4bit[8] = {2,6,4,8,1,5,3,7}; // pre-defined order for bit-interleaving in FP6-LLM + unsigned int Frags_4bit[8]; // The highest4 bits are used to store the extracted fragments. + for(int i=0; i<8; i++) + Frags_4bit[i] = ( input << 4*(order_4bit[i]-1) ) & 0xf0000000; + // + unsigned int output = 0x00000000; + for(int i=0; i<8; i++) + output |= ( Frags_4bit[i] >> (i*4) ); + // + *PTR_UINT = output; +} + +/* + * Inputs: + * (1) unsigned char Weight_6bit [M*K*6/8] + * Outputs: + * (1) unsigned char Weight_2bit [M*K*2/8] + * (2) unsigned char Weight_4bit [M*K*4/8] + * + * Assumption: Weight_6bit, Weight_2bit, Weight_4bit all stored continuously in row-major. + * 8 FP6 = 6 Bytes + * 8 FP4 = 4 Bytes + * 8 FP2 = 2 Bytes + */ +void weight_matrix_prepacking(int* packed_weights, int *FP6Weights, size_t M, size_t K) +{ + assert(M % 64 == 0); + assert(K % 64 == 0); + // + unsigned char* Weight_6bit = reinterpret_cast(FP6Weights); + unsigned char* Weight_2bit = reinterpret_cast(packed_weights); + unsigned char* Weight_4bit = Weight_2bit + M*K*2/8; + // + vector A_Segment_2bit[32]; + vector A_Segment_4bit[32]; + // + size_t BytesPerRow = K*6/8; + // Pass-1: (1) 2+4 split; (2) assign weights to 32 threads. + for (size_t i = 0; i < M / 64; i++) // + { + for (size_t j = 0; j < K / 16; j++) + { + for(size_t k=0; k<64/16; k++) + { + size_t row = i*64 + k*16; + size_t col = j*16; + unsigned char* StartPTR_1 = Weight_6bit + row*BytesPerRow + col*6/8; + unsigned char* StartPTR_2 = StartPTR_1 + 8*BytesPerRow; + unsigned char* StartPTR_3 = StartPTR_1 + 8*6/8; + unsigned char* StartPTR_4 = StartPTR_2 + 8*6/8; + // Dealing with each 16*16 blocks then... + for(int l=0; l<8; l++) Assign_32_FP6_To_4_Thread(&A_Segment_2bit[l*4], &A_Segment_4bit[l*4], StartPTR_1+l*BytesPerRow, StartPTR_2+l*BytesPerRow, StartPTR_3+l*BytesPerRow, StartPTR_4+l*BytesPerRow); + } + } + } + // Verifying the length of 2_bit segments and 4_bit segments + size_t BytesPerThread_2bit = M*K*2/8/32; + size_t BytesPerThread_4bit = M*K*4/8/32; + for(int i=0; i<32; i++) + { + assert(A_Segment_2bit[i].size()==BytesPerThread_2bit); + assert(A_Segment_4bit[i].size()==BytesPerThread_4bit); + } + // Pass-2: Optimizing coleasced global memory access + for(size_t i=0; i +#include + +namespace torchao { + +/* + * Weight prepacking (Pytorch interface). + * [Input & Output] + * fp6_tensor: int tensor of shape [OC, IC // 16 * 3]; // 3 INT32 words contains 16 FP6 weights. 
+ * [Output] + * packed_tensor: int tensor of shape [OC, IC // 16 * 3]; + */ +at::Tensor weight_matrix_prepacking_cpu(at::Tensor fp6_tensor) +{ + size_t OC = fp6_tensor.size(0); + size_t IC = fp6_tensor.size(1); + TORCH_CHECK(IC % 3 == 0, "Expect packed input dim % 3 == 0, but receive ", IC, " instead."); + IC = IC * 16 / 3; + TORCH_CHECK((OC % 256 == 0) && (IC % 64 == 0), "Expect output dim % 256 == 0 and input dim % 64 == 0, but receive ", OC, " and ", IC, " instead."); + auto packed_tensor = at::empty_like(fp6_tensor); + auto packed_tensor_ptr = reinterpret_cast(packed_tensor.data_ptr()); + auto fp6_tensor_ptr = reinterpret_cast(fp6_tensor.data_ptr()); + weight_matrix_prepacking(packed_tensor_ptr, fp6_tensor_ptr, OC, IC); + return packed_tensor; +} + +TORCH_LIBRARY_IMPL(torchao, CPU, m) { + m.impl("torchao::prepack_fp6_weight", &weight_matrix_prepacking_cpu); +} + +} diff --git a/torchao/ops.py b/torchao/ops.py index 0931d32026..3a25dbf6db 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -21,3 +21,88 @@ def _(dets, scores, iou_threshold): ctx = torch._custom_ops.get_ctx() num_to_keep = ctx.create_unbacked_symint() return dets.new_empty(num_to_keep, dtype=torch.long) + + +def prepack_fp6_weight(fp6_weight: Tensor) -> Tensor: + """ + Pack FP6 tensor in a layout for use with FP6-LLM. See https://arxiv.org/abs/2401.14112 for more details. + + Arguments + fp6_weight: tightly-packed fp6_weight, inside a `torch.int32` container + + Returns + packed FP6 tensor for use with FP6-LLM, inside a `torch.int32` container + """ + return torch.ops.torchao.prepack_fp6_weight.default(fp6_weight) + + +@torch.library.impl_abstract("torchao::prepack_fp6_weight") +def _(fp6_weight): + torch._check(fp6_weight.dim() == 2, lambda: f"weight should be a 2d tensor, got {fp6_weight.dim()}D") + return torch.empty_like(fp6_weight) + + +def fp16_to_fp6(fp16_tensor: Tensor) -> Tensor: + """ + Pack FP16 tensor (containing only FP6 values) into FP6 tensor. + """ + return torch.ops.torchao.fp16_to_fp6.default(fp16_tensor) + + +@torch.library.impl_abstract("torchao::fp16_to_fp6") +def _(fp16_tensor): + torch._check(fp16_tensor.dim() == 2, lambda: f"weight should be a 2d tensor, got {fp16_tensor.dim()}D") + torch._check(fp16_tensor.dtype is torch.float16, lambda: f"weight must be FP16, got {fp16_tensor.dtype}") + M, K = fp16_tensor.shape + torch._check(K % 4 == 0, lambda: f"second dimension must be a multiple of 4, got {K}") + return torch.empty((M, K * 6 // 8), dtype=torch.uint8, device=fp16_tensor.device) + + +def fp16act_fp6weight_linear(_in_feats: Tensor, _weights: Tensor, _scales: Tensor, splitK: int = 1) -> Tensor: + """ + FP6-LLM linear layer A @ W.T. See https://arxiv.org/abs/2401.14112 for more details. + + Arguments + _in_feats: input activations in FP16 + _weights: packed FP6 weights. 
See :func:`prepack_fp6_weight` and :func:`fp16_to_fp6`
+        _scales: FP16 per-output-channel (row-wise) quantization scales, shape [OC]
+        splitK: split-K factor used by the CUDA kernel (1 means no split-K)
+
+    Returns
+        output of linear layer
+    """
+    return torch.ops.torchao.fp16act_fp6weight_linear.default(_in_feats, _weights, _scales, splitK)
+
+
+@torch.library.impl_abstract("torchao::fp16act_fp6weight_linear")
+def _(_in_feats, _weights, _scales, splitK = 1):
+    torch._check(_in_feats.dim() == 2, lambda: f"input should be a 2d tensor, got {_in_feats.dim()}D")
+    torch._check(_in_feats.dtype is torch.float16, lambda: f"input must be FP16, got {_in_feats.dtype}")
+    torch._check(_weights.dim() == 2, lambda: f"weight should be a 2d tensor, got {_weights.dim()}D")
+    torch._check(_weights.dtype is torch.int32, lambda: f"weight must be INT32, got {_weights.dtype}")
+    torch._check(_scales.dim() == 1, lambda: f"scale should be a 1d tensor, got {_scales.dim()}D")
+    torch._check(_scales.dtype is torch.float16, lambda: f"scale must be FP16, got {_scales.dtype}")
+
+    BS, IC = _in_feats.shape
+    OC, _ = _weights.shape
+    torch._check(IC // 16 * 3 == _weights.shape[1], lambda: "Dimensions mismatched")
+    torch._check(OC == _scales.shape[0], lambda: "Dimensions mismatched")
+
+    return _in_feats.new_empty((BS, OC))
+
+
+def fp6_weight_dequant(fp6_tensor: Tensor, fp16_scale: Tensor) -> Tensor:
+    """
+    Dequantize a packed FP6 weight (stored in a `torch.int32` container) to FP16 using the given
+    per-output-channel scales. Useful for building an FP16 reference for the FP6-LLM linear op.
+    """
+    return torch.ops.torchao.fp6_weight_dequant.default(fp6_tensor, fp16_scale)
+
+
+@torch.library.impl_abstract("torchao::fp6_weight_dequant")
+def _(fp6_tensor, fp16_scale):
+    torch._check(fp6_tensor.dim() == 2, lambda: f"weight should be a 2d tensor, got {fp6_tensor.dim()}D")
+    torch._check(fp6_tensor.dtype is torch.int32, lambda: f"weight must be INT32, got {fp6_tensor.dtype}")
+    torch._check(fp16_scale.dim() == 1, lambda: f"scale should be a 1d tensor, got {fp16_scale.dim()}D")
+    torch._check(fp16_scale.dtype is torch.float16, lambda: f"scale must be FP16, got {fp16_scale.dtype}")
+
+    OC, _IC = fp6_tensor.shape
+    torch._check(OC == fp16_scale.shape[0], lambda: "Dimensions mismatched")
+
+    return fp16_scale.new_empty((OC, _IC * 16 // 3))
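Usage sketch for the new ops (a minimal example rather than code from this patch). It assumes a CUDA build of torchao with these kernels compiled, weights that already hold exact FP6 (E3M2) values so that fp16_to_fp6 does not raise, and that the uint8 output of fp16_to_fp6 can be reinterpreted as the int32 container expected by prepack_fp6_weight and fp6_weight_dequant via Tensor.view(torch.int32); the shapes are picked to satisfy the documented constraints (output dim % 256 == 0, input dim % 64 == 0).

import torch
import torchao

OC, IC, BS = 256, 1024, 4  # output dim % 256 == 0, input dim % 64 == 0

# Sample weights from a handful of exactly representable FP6 (E3M2) values,
# since fp16_to_fp6 raises on values outside the FP6 range rather than rounding.
fp6_vals = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, -0.5, -1.0, -1.5])
fp16_weight = fp6_vals[torch.randint(len(fp6_vals), (OC, IC))].half()

fp6_packed = torchao.ops.fp16_to_fp6(fp16_weight)   # uint8, shape (OC, IC * 6 // 8)
fp6_int32 = fp6_packed.view(torch.int32)            # same bytes, shape (OC, IC // 16 * 3)
fp6_prepacked = torchao.ops.prepack_fp6_weight(fp6_int32)

scales = torch.ones(OC, dtype=torch.half)            # per-output-channel scales
act = torch.randn(BS, IC, dtype=torch.half, device="cuda")

out = torchao.ops.fp16act_fp6weight_linear(act, fp6_prepacked.cuda(), scales.cuda(), splitK=1)

# FP16 reference built with the CPU dequant helper
ref_weight = torchao.ops.fp6_weight_dequant(fp6_int32, scales).cuda()
ref = act @ ref_weight.T
print((out - ref).abs().max())

With scales of ones, the FP6 path and the dequantized FP16 reference should agree up to FP16 accumulation differences, which is the same comparison the kernel's own tests rely on.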