Add FP16Act-FP6Weight Linear (pytorch#223)
* add files from fp6_llm
* try to port weight packing first
* rename
* rename fp6 weight packing
* add fp16act_fp6weight_linear
* fix function def
* delete duplicate file
* move weight quant file
* rename
* add pytorch interface for fp6 weight dequant
* add fake_fp6 to fp6
* move weight_quant to csrc/cuda due to cuda_fp16.h dependency
* add fake_fp6_to_fp6 test
* add test for fp16act_fp6weight_linear
* add test for fp6_weight_dequant
* Fp6WeightOnlyQuantizedLinearWeight (not working yet)
* skip some tests, since the functions are not built w/o CUDA
* add the original test
* implement transpose and clone so that F.linear will work
* remove print
* remove dequantize
* add notes and some rename
* typo
* small cleanup
* improve tensor subclass and add test (which is failing for torch-compile)
* add note
* add note
* add qtorch as dev requirement
* update error message
* add __repr__ and fix transposed issue
* add fp6 perplexity test
* rename variables
* remove subclass
* add correctness test
* remove unwanted changes
* add apache 2.0 notice
* add benchmark script
* add note about FP6 kernel
* relax tolerance

---------

Co-authored-by: Mark Saroufim <[email protected]>
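Taken together, the new ops expose an FP16-activation x FP6-weight linear on the Python side. The sketch below shows how the three torchao.ops calls exercised in the benchmark script further down are meant to fit together; the sizes and the splitK value here are illustrative assumptions, not recommended settings.

import torch
import torchao

# Illustrative sizes only (hypothetical); k must be divisible by 16 for the packed layout.
m, k, n = 8, 8192, 10240
splitK = 6  # kernel split-K factor; treated here as an assumed tuning knob

# Packed FP6 weights: each row stores k six-bit values in k // 16 * 3 int32 words (16 values = 96 bits = 3 words).
fp6_weight = torch.randint(4294967295, (n, k // 16 * 3)).to(torch.int)
fp16_scale = torch.rand(n).half() + 0.5
fp16_act = torch.rand(m, k).half() + 0.5

# Rearrange the FP6 words into the layout the CUDA kernel expects.
packed_weight = torchao.ops.prepack_fp6_weight(fp6_weight).cuda()

# FP16 activation x FP6 weight linear, computed by the fused kernel.
fp6_out = torchao.ops.fp16act_fp6weight_linear(fp16_act.cuda(), packed_weight, fp16_scale.cuda(), splitK)

# Reference path: dequantize the FP6 weights to FP16 and use a plain matmul.
fp16_weight = torchao.ops.fp6_weight_dequant(fp6_weight, fp16_scale).cuda()
fp16_out = fp16_act.cuda() @ fp16_weight.T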
Showing 17 changed files with 1,882 additions and 2 deletions.
@@ -0,0 +1,82 @@
import torch
import torchao
from torch.utils.benchmark import Timer
import pandas as pd
from tqdm import tqdm


def benchmark(m, k, n, splitK):
    # Randomly initialize each byte. The highest value for randint() is set to the max value of uint32_t.
    # Each row packs k 6-bit weights into k // 16 * 3 int32 words (16 values = 96 bits = 3 words).
    fp6_weight = torch.randint(4294967295, (n, k // 16 * 3)).to(torch.int)
    fp16_scale = torch.rand(n).half() + 0.5
    fp16_activation = torch.rand(m, k).half() + 0.5

    fp6_weight_packed = torchao.ops.prepack_fp6_weight(fp6_weight)
    act_cuda = fp16_activation.cuda()
    weight_cuda = fp6_weight_packed.cuda()
    scale_cuda = fp16_scale.cuda()

    # need to do this since Timer cannot see torchao
    def fp6_linear(act_cuda, weight_cuda, scale_cuda, splitK):
        return torchao.ops.fp16act_fp6weight_linear(act_cuda, weight_cuda, scale_cuda, splitK)

    fp6_output = fp6_linear(act_cuda, weight_cuda, scale_cuda, splitK)

    fp6_measurement = Timer(
        stmt="fp6_linear(act_cuda, weight_cuda, scale_cuda, splitK)",
        globals=locals(),
    ).blocked_autorange()

    fp16_weight = torchao.ops.fp6_weight_dequant(fp6_weight, fp16_scale).cuda()
    fp16_output = act_cuda @ fp16_weight.T

    fp16_measurement = Timer(
        stmt="act_cuda @ fp16_weight.T",
        globals=locals(),
    ).blocked_autorange()

    # follow https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/tests/python/kernel_test.py
    # doesn't seem to be the right way to check for correctness
    correct = (fp6_output - fp16_output).abs().mean() / fp16_output.abs().mean() < 1e-3

    return {
        "m": m,
        "k": k,
        "n": n,
        "fp6_latency (ms)": fp6_measurement.median * 1000,
        "fp16_latency (ms)": fp16_measurement.median * 1000,
        "speedup (d/s)": fp16_measurement.median / fp6_measurement.median,
        "correct": correct,
    }


if __name__ == "__main__":
    # from https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/tests/python/run.sh
    k_vals = (8192, 8192, 8192, 28672)
    n_vals = (10240, 8192, 57344, 8192)

    results = []

    # splitK can be tuned based on m, k, n
    for m, splitK_vals in tqdm([
        (1, (5, 6, 7, 6)),
        (2, (5, 6, 7, 6)),
        (4, (5, 6, 7, 6)),
        (8, (5, 6, 7, 6)),
        # (16, (5, 6, 7, 6)),
        # (64, (5, 6, 7, 6)),
        # (128, (5, 3, 3, 3)),
        # (256, (4, 3, 2, 3)),
        # (512, (2, 5, 2, 4)),
        (1024, (1, 2, 1, 2)),
        (2048, (1, 1, 1, 1)),
        (4096, (1, 1, 1, 1)),
        # (8192, (1, 1, 1, 1)),
        # (16384, (1, 1, 1, 1)),
    ]):
        for n, k, splitK in zip(n_vals, k_vals, splitK_vals):
            results.append(benchmark(m, k, n, splitK))

    df = pd.DataFrame(results)
    df.to_csv("fp6_benchmark_results.csv", index=False)
    print(df.to_markdown(index=False))
@@ -0,0 +1,90 @@
// Copyright 2024 FP6-LLM authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// This file is copied from https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/include/configs.h

#ifndef CONFIGS_H
#define CONFIGS_H

//#define DEBUG_MODE
#define PIPELINE_LEVEL_GMEM 2
#define PIPELINE_LEVEL_SMEM 2       // only support 2

/************************ Hardware Parameters ************************/
#define WARP_SIZE                           32
#define REG_BIT_WIDTH                       32
// mma: M=16 K=16 N=8
#define MMA_8                               8
#define MMA_16                              16
// for memory access
#define THREAD_OPT_ACCESS_BIT_WIDTH_128     128 // LDS.128, cp_async.128, ...
#define BIT_WIDTH_PER_HALF                  16  // Half precision: FP16

/******************** Register Allocation For GEMM ********************/
#define REG_PER_THREAD_C_TENSOR_16_16       8   // 8 for FP32 accumulation
/********************** Memory Padding Parameters **********************/
// Eliminating bank conflicts
#define PADDING_BYTES_16                    16  // Padding 16 bytes each column
#define PADDING_SHARED_MEM_FOR_B_8          8   // Padding 8 halves each column, during CopyFromGlobalToShared() for B
#define PADDING_SHARED_MEM_FOR_C_4          4   // Padding 4 floats each column, during StoreToSharedMemoryFromRegister() for C
/************************* WARP Tiling part-1 *************************/
#define WARP_ROW_MMA_TENSORS                4
#define WARP_M                              (WARP_ROW_MMA_TENSORS * MMA_16)   // 64
#define WARP_K_MMA_TENSORS                  4
#define WARP_K                              (WARP_K_MMA_TENSORS * MMA_16)     // 64
template<int BLOCK_ROW_WARPS_, int BLOCK_COL_WARPS_, int WARP_COL_MMA_TENSORS_>
struct TilingConfig {
    // Depending on the "n" dimension of the GEMM
    static constexpr int BLOCK_ROW_WARPS      = BLOCK_ROW_WARPS_;
    static constexpr int BLOCK_COL_WARPS      = BLOCK_COL_WARPS_;
    static constexpr int WARP_COL_MMA_TENSORS = WARP_COL_MMA_TENSORS_;
    /************************* WARP Tiling part-2 *************************/
    static constexpr int WARP_N               = WARP_COL_MMA_TENSORS * MMA_8;
    /************************ Thread Block Tiling *************************/
    static constexpr int TILE_M               = WARP_M * BLOCK_ROW_WARPS;
    static constexpr int TILE_N               = MMA_8 * WARP_COL_MMA_TENSORS * BLOCK_COL_WARPS;
    static constexpr int TILE_K               = WARP_K;
    /********************** #Thread per Thread Block **********************/
    static constexpr int BLOCK_WARPS          = BLOCK_ROW_WARPS * BLOCK_COL_WARPS;
    static constexpr int BLOCK_THREADS        = BLOCK_WARPS * WARP_SIZE;
    /******************************* Others *******************************/
    static constexpr int SMEM_SIZE_B_TILE     = TILE_N * (TILE_K + PADDING_BYTES_16) * 2 * PIPELINE_LEVEL_GMEM;   // sizeof(half)=2, doubleBuffer=2
    static constexpr int SMEM_SIZE_C_TILE     = TILE_N * (TILE_M + PADDING_BYTES_16) * 4;                         // sizeof(float)=4
};

/************************ General Config for FP6-LLM **********************/
#define WEIGHT_FRAG1_BIT_WIDTH              2
#define WEIGHT_FRAG2_BIT_WIDTH              4
#define WEIGHT_BIT_WIDTH                    (WEIGHT_FRAG1_BIT_WIDTH + WEIGHT_FRAG2_BIT_WIDTH)   // 6
//#define QUANT_GROUP_SIZE_DIVIDED_BY_64      4   // QuantGroupSize: 4*64 = 256
/*************************** 64*64 Weights of a WARP *************************/
#define WEIGHT_PER_UNIT                     (WARP_M * WARP_K)                                   // 64*64
#define SMEM_SIZE_IN_BYTES_PER_WARP_A1      (WEIGHT_PER_UNIT * WEIGHT_FRAG1_BIT_WIDTH / 8)      // 1024 bytes; doubleBuffer not taken into consideration
#define SMEM_SIZE_IN_BYTES_PER_WARP_A2      (WEIGHT_PER_UNIT * WEIGHT_FRAG2_BIT_WIDTH / 8)      // 2048 bytes; doubleBuffer not taken into consideration
#define SMEM_SIZE_A1_TILE                   (SMEM_SIZE_IN_BYTES_PER_WARP_A1 * 4 * PIPELINE_LEVEL_GMEM)  // #WARP=4; triple buffer for a 3-level pipeline for A = 12 KB; double buffer for a 2-level pipeline for A = 8 KB
#define SMEM_SIZE_A2_TILE                   (SMEM_SIZE_IN_BYTES_PER_WARP_A2 * 4 * PIPELINE_LEVEL_GMEM)  // #WARP=4; triple buffer for a 3-level pipeline for A = 24 KB; double buffer for a 2-level pipeline for A = 16 KB
/******************** Global Memory Layout For QUANTIZED DATA ******************/
#define NUM_INT4_PER_UNIT_2BIT_FRAG         (WEIGHT_PER_UNIT * WEIGHT_FRAG1_BIT_WIDTH / 128)    // 64
#define NUM_INT4_PER_UNIT_4BIT_FRAG         (WEIGHT_PER_UNIT * WEIGHT_FRAG2_BIT_WIDTH / 128)    // 128
/******************** Register Allocation For QUANTIZED DATA ******************/
#define WEIGHT_PER_THREAD                   (WEIGHT_PER_UNIT / WARP_SIZE)                       // 128
#define REG_PER_THREAD_2BIT_FRAG            (WEIGHT_PER_THREAD / REG_BIT_WIDTH * 2)             // 8
#define REG_PER_THREAD_4BIT_FRAG            (WEIGHT_PER_THREAD / REG_BIT_WIDTH * 4)             // 16
/******************** Register Allocation For QUANT Scales ******************/
#define WARP_REG_QUANT_SCALE                4   // 8 rows per thread -> 8 FP16 scales -> 4 registers
#define WARP_REG_QUANT_SCALE_DISTRIBUTED    1   // T0-T3, T4-T7, ..., T28-T31 share the same scales, using shfl to get all the scales for each thread

#endif // CONFIGS_H
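The TilingConfig template above derives every tile and shared-memory size from a handful of warp-level constants. As a quick sanity check of that arithmetic, here is a small Python sketch that recomputes the constexpr members for one hypothetical instantiation; the template arguments (4, 1, 8) are chosen only for illustration and are not necessarily what the kernel dispatches to.

# Mirror of the fixed constants in configs.h.
WARP_SIZE = 32
MMA_8, MMA_16 = 8, 16
PADDING_BYTES_16 = 16
PIPELINE_LEVEL_GMEM = 2
WARP_ROW_MMA_TENSORS = WARP_K_MMA_TENSORS = 4
WARP_M = WARP_ROW_MMA_TENSORS * MMA_16   # 64
WARP_K = WARP_K_MMA_TENSORS * MMA_16     # 64

def tiling_config(block_row_warps, block_col_warps, warp_col_mma_tensors):
    """Re-derive the TilingConfig constexpr members for one template instantiation."""
    warp_n = warp_col_mma_tensors * MMA_8
    tile_m = WARP_M * block_row_warps
    tile_n = MMA_8 * warp_col_mma_tensors * block_col_warps
    tile_k = WARP_K
    block_threads = block_row_warps * block_col_warps * WARP_SIZE
    smem_b = tile_n * (tile_k + PADDING_BYTES_16) * 2 * PIPELINE_LEVEL_GMEM  # sizeof(half)=2, doubleBuffer=2
    smem_c = tile_n * (tile_m + PADDING_BYTES_16) * 4                        # sizeof(float)=4
    return dict(WARP_N=warp_n, TILE_M=tile_m, TILE_N=tile_n, TILE_K=tile_k,
                BLOCK_THREADS=block_threads, SMEM_SIZE_B_TILE=smem_b, SMEM_SIZE_C_TILE=smem_c)

# Hypothetical instantiation: 4 x 1 warps per block, 8 MMA tiles along N per warp.
print(tiling_config(4, 1, 8))
# {'WARP_N': 64, 'TILE_M': 256, 'TILE_N': 64, 'TILE_K': 64,
#  'BLOCK_THREADS': 128, 'SMEM_SIZE_B_TILE': 20480, 'SMEM_SIZE_C_TILE': 69632}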