From 1b0be66036ee1a50bdcad8a34b58102c8f21ff95 Mon Sep 17 00:00:00 2001 From: Pan Zhaowu Date: Thu, 7 Aug 2025 12:48:49 +0000 Subject: [PATCH 1/5] Enhanced fused_transpose_split_quant with fp8 capability. --- paddle/phi/infermeta/fusion.cc | 11 +- paddle/phi/infermeta/fusion.h | 1 + .../gpu/fused_transpose_split_quant_kernel.cu | 114 ++++++++++++------ paddle/phi/ops/yaml/fused_ops.yaml | 3 +- python/paddle/incubate/nn/functional/fp8.py | 6 +- .../test_fused_transpose_split_quant_op.py | 91 ++++++++++---- 6 files changed, 160 insertions(+), 66 deletions(-) diff --git a/paddle/phi/infermeta/fusion.cc b/paddle/phi/infermeta/fusion.cc index 0152ddcebc90bf..5a58a7bd36c188 100644 --- a/paddle/phi/infermeta/fusion.cc +++ b/paddle/phi/infermeta/fusion.cc @@ -2421,16 +2421,17 @@ void FusedMultiTransformerInt8InferMeta( } void FusedTransposeSplitQuantInferMeta(const MetaTensor& x, + const MetaTensor& input_scales, const IntArray& tokens_per_expert, bool pow_2_scales, std::vector outs, std::vector scales) { PADDLE_ENFORCE_EQ( - x.dtype(), - DataType::BFLOAT16, - common::errors::InvalidArgument( - "The dtype of Input(x) must be BFLOAT16, but received %s", - x.dtype())); + x.dtype() == DataType::BFLOAT16 || x.dtype() == DataType::FLOAT8_E4M3FN, + true, + common::errors::InvalidArgument("The dtype of Input(x) must be BFLOAT16 " + "or FLOAT8_E4M3FN, but received %s", + x.dtype())); auto x_dims = x.dims(); diff --git a/paddle/phi/infermeta/fusion.h b/paddle/phi/infermeta/fusion.h index 6e840d67ead536..0c6c03cc580c28 100644 --- a/paddle/phi/infermeta/fusion.h +++ b/paddle/phi/infermeta/fusion.h @@ -669,6 +669,7 @@ void FusedMultiTransformerInt8InferMeta( MetaTensor* out); void FusedTransposeSplitQuantInferMeta(const MetaTensor& x, + const MetaTensor& input_scales, const IntArray& tokens_per_expert, bool pow_2_scales, std::vector outs, diff --git a/paddle/phi/kernels/fusion/gpu/fused_transpose_split_quant_kernel.cu b/paddle/phi/kernels/fusion/gpu/fused_transpose_split_quant_kernel.cu index 23ddde393f3dd2..ce9611bf278426 100644 --- a/paddle/phi/kernels/fusion/gpu/fused_transpose_split_quant_kernel.cu +++ b/paddle/phi/kernels/fusion/gpu/fused_transpose_split_quant_kernel.cu @@ -29,22 +29,33 @@ struct __align__(sizeof(T) * VecSize) VecType { } }; -template -__device__ void BlockLoad(const phi::bfloat16* input, +template +__device__ void BlockLoad(const InT* input, + const float* input_scales, __nv_bfloat16 x[8][4], size_t K) { + constexpr bool need_dequant = std::is_same_v; for (uint32_t i = 0; i < 8; i++) { - size_t off_m = blockIdx.x * size_t(128) + threadIdx.y + i * 16; + size_t local_off_M = threadIdx.y + i * 16; + size_t off_m = blockIdx.x * size_t(128) + local_off_M; size_t off_k = blockIdx.y * 128 + threadIdx.x * VecSize; size_t offset = off_m * K + off_k; + float scale; + if constexpr (need_dequant) { + scale = input_scales[local_off_M]; + } for (uint32_t j = 0; j < 4; j += VecSize) { if (off_k + j * 32 < K) { size_t idx = offset + j * 32; - using LoadT = VecType<__nv_bfloat16, VecSize>; + using LoadT = VecType; LoadT data = *reinterpret_cast(input + idx); for (uint32_t k = 0; k < VecSize; k++) { - x[i][j + k] = data[k]; + if constexpr (need_dequant) { + x[i][j + k] = __float2bfloat16(static_cast(data[k]) * scale); + } else { + x[i][j + k] = (*reinterpret_cast<__nv_bfloat16*>(&data[k])); + } } } } @@ -53,7 +64,7 @@ __device__ void BlockLoad(const phi::bfloat16* input, template __device__ void BlockColumnScale(const __nv_bfloat16 x[8][4], - float col_scale[128], + float scales[128], __nv_bfloat16* shm) { // reduce [(8), 16, 32, 4] => [16, 32, 4] __nv_bfloat16 warp_max[4]; @@ -85,7 +96,7 @@ __device__ void BlockColumnScale(const __nv_bfloat16 x[8][4], if (offset > 1) { shm[threadIdx.y * 128 + threadIdx.x + j * 32] = next_val; } else { - col_scale[threadIdx.x + j * 32] = + scales[threadIdx.x + j * 32] = ComputeScale<__nv_bfloat16, __nv_fp8_e4m3, Pow2Scales>( static_cast(next_val), 0.0f); } @@ -98,7 +109,7 @@ __device__ void BlockColumnScale(const __nv_bfloat16 x[8][4], template __device__ void BlockStoreScale(float* scale, size_t off_m, - float col_scale[128], + float scales[128], size_t K) { if (threadIdx.y < 4) { uint32_t off = threadIdx.y * 32 + threadIdx.x; @@ -107,7 +118,7 @@ __device__ void BlockStoreScale(float* scale, } else if constexpr (VecSize == 2) { off = (off / 64) * 64 + (off % 2) * 32 + (off % 64) / 2; } - float scale_out = 1.0f / col_scale[off]; + float scale_out = 1.0f / scales[off]; size_t idx_y = blockIdx.x - off_m / 128; size_t idx_x = blockIdx.y * 128 + threadIdx.y * 32 + threadIdx.x; size_t idx = idx_y * K + idx_x; @@ -139,23 +150,39 @@ __device__ void BlockStoreOut(OutT* out, } } -template +template __global__ void __launch_bounds__(512) - FusedTransposeSplitQuantKernel(const phi::bfloat16* __restrict__ input, + FusedTransposeSplitQuantKernel(const InT* __restrict__ input, + const float* __restrict__ input_scales, int64_t* __restrict__ meta, size_t num_experts, - size_t K) { + size_t K, + size_t k_scaled) { __shared__ OutT shm[128][129]; + __shared__ float scales[128]; + __shared__ size_t expert_info[2]; + + // 0. Load input_scales if float8 + if constexpr (std::is_same_v) { + const int tid = blockDim.y * threadIdx.x + threadIdx.y; + if (tid < 128) { + const size_t m_base = blockIdx.x * size_t(128); + const size_t k_base = blockIdx.y * size_t(128); + const size_t m_stride = k_scaled; + const size_t input_scale_offset = (m_base + tid) * m_stride + blockIdx.y; + scales[tid] = input_scales[input_scale_offset]; + } + __syncthreads(); + } int64_t* tokens_per_expert = meta; OutT** out_ptrs = reinterpret_cast(meta + num_experts); float** scale_ptrs = reinterpret_cast(meta + num_experts * 2); // 1. Load 128x128 elements from input __nv_bfloat16 x[8][4]; - BlockLoad(input, x, K); + BlockLoad(input, scales, x, K); // 2. Get expert index and offset of the current block - __shared__ size_t expert_info[2]; if (threadIdx.x == 0 && threadIdx.y == 0) { size_t idx_m = blockIdx.x * size_t(128); size_t off_m = 0, next_off_m = 0; @@ -171,22 +198,25 @@ __global__ void __launch_bounds__(512) expert_info[1] = off_m; } + if constexpr (std::is_same_v) { + __syncthreads(); + } + // 3. Calculate scale along the column - __shared__ float col_scale[128]; BlockColumnScale( - x, col_scale, reinterpret_cast<__nv_bfloat16*>(shm)); + x, scales, reinterpret_cast<__nv_bfloat16*>(shm)); // 4. Store scale const size_t expert_idx = expert_info[0]; const size_t off_m = expert_info[1]; - BlockStoreScale(scale_ptrs[expert_idx], off_m, col_scale, K); + BlockStoreScale(scale_ptrs[expert_idx], off_m, scales, K); // 5. Scale x and save into shared memory with transposed layout for (uint32_t i = 0; i < 8; i++) { for (uint32_t j = 0; j < 4; j += VecSize) { for (uint32_t k = 0; k < VecSize; k++) { float x_fp32 = static_cast(x[i][j + k]); - float x_scaled = x_fp32 * col_scale[threadIdx.x + (j + k) * 32]; + float x_scaled = x_fp32 * scales[threadIdx.x + (j + k) * 32]; shm[threadIdx.x * VecSize + j * 32 + k][i * 16 + threadIdx.y] = static_cast(x_scaled); } @@ -204,10 +234,11 @@ template void FusedTransposeSplitQuantKernel( const Context& dev_ctx, const DenseTensor& x, + const paddle::optional& input_scales, const std::vector& tokens_per_expert, bool pow_2_scales, std::vector outs, - std::vector scales) { + std::vector output_scales) { auto x_dims = x.dims(); const int64_t M = x_dims[0]; const int64_t K = x_dims[1]; @@ -221,8 +252,8 @@ void FusedTransposeSplitQuantKernel( if (outs[i] != nullptr) { dev_ctx.template Alloc(outs[i]); } - if (scales[i] != nullptr) { - dev_ctx.template Alloc(scales[i]); + if (output_scales[i] != nullptr) { + dev_ctx.template Alloc(output_scales[i]); } } @@ -245,8 +276,8 @@ void FusedTransposeSplitQuantKernel( for (size_t i = 0; i < num_experts; i++) { meta_ptr[num_experts * 2 + i] = - scales[i] != nullptr - ? reinterpret_cast(scales[i]->data()) + output_scales[i] != nullptr + ? reinterpret_cast(output_scales[i]->data()) : 0; } @@ -255,23 +286,35 @@ void FusedTransposeSplitQuantKernel( auto stream = dev_ctx.stream(); - dim3 grid(M / 128, (K + 127) / 128); + // pre-compute on CPU to reduce size_t division cost in kernel + const size_t k_scaled = (K + 127) / 128; + dim3 grid(M / 128, k_scaled); dim3 block(32, 16); -#define LAUNCH_KERNEL(POW_2_SCALES, VEC_SIZE) \ - FusedTransposeSplitQuantKernel \ - <<>>(x.data(), \ - meta_gpu.data(), \ - num_experts, \ - K); +#define DTYPE_CASE(dtype, type) dtype == phi::DataType::type +#define LAUNCH_KERNEL(T, POW_2_SCALES, VEC_SIZE) \ + FusedTransposeSplitQuantKernel<<>>( \ + x.data(), \ + input_scales ? input_scales.get_ptr()->data() : nullptr, \ + meta_gpu.data(), \ + num_experts, \ + K, \ + k_scaled); +#define DISPATCH_DATATYPE(POW_2_SCALES, VEC_SIZE) \ + if (DTYPE_CASE(x.dtype(), BFLOAT16)) { \ + LAUNCH_KERNEL(phi::bfloat16, POW_2_SCALES, VEC_SIZE); \ + } else if (DTYPE_CASE(x.dtype(), FLOAT8_E4M3FN)) { \ + LAUNCH_KERNEL(phi::float8_e4m3fn, POW_2_SCALES, VEC_SIZE); \ + } #define LAUNCH_KERNEL_PARTIAL(VEC_SIZE) \ if (pow_2_scales) { \ - LAUNCH_KERNEL(true, VEC_SIZE); \ + DISPATCH_DATATYPE(true, VEC_SIZE); \ } else { \ - LAUNCH_KERNEL(false, VEC_SIZE); \ + DISPATCH_DATATYPE(false, VEC_SIZE); \ } if (K % 4 == 0) { @@ -296,7 +339,8 @@ PD_REGISTER_KERNEL(fused_transpose_split_quant, double, int, int64_t, - phi::dtype::bfloat16) { + phi::dtype::bfloat16, + phi::dtype::float8_e4m3fn) { kernel->OutputAt(0).SetDataType(phi::DataType::FLOAT8_E4M3FN); kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32); } diff --git a/paddle/phi/ops/yaml/fused_ops.yaml b/paddle/phi/ops/yaml/fused_ops.yaml index 291147c33367bf..991b1ab8c0ab6d 100644 --- a/paddle/phi/ops/yaml/fused_ops.yaml +++ b/paddle/phi/ops/yaml/fused_ops.yaml @@ -916,12 +916,13 @@ support_dygraph_mode : true - op: fused_transpose_split_quant - args: (Tensor x, IntArray tokens_per_expert, bool pow_2_scales=false) + args: (Tensor x, Tensor input_scales, IntArray tokens_per_expert, bool pow_2_scales=false) output: Tensor[](out){tokens_per_expert.size()}, Tensor[](scales){tokens_per_expert.size()} infer_meta: func: FusedTransposeSplitQuantInferMeta kernel: func: fused_transpose_split_quant + optional: input_scales support_dygraph_mode : true - op: fused_weighted_swiglu_act_quant diff --git a/python/paddle/incubate/nn/functional/fp8.py b/python/paddle/incubate/nn/functional/fp8.py index 7c524b865ee96b..c430f5ee6076d2 100644 --- a/python/paddle/incubate/nn/functional/fp8.py +++ b/python/paddle/incubate/nn/functional/fp8.py @@ -173,7 +173,9 @@ def fused_swiglu_weighted_bwd( return _C_ops.fused_swiglu_weighted_bwd(o1, do2_s, unzipped_probs) -def fused_transpose_split_quant(x, tokens_per_expert, pow_2_scales=False): +def fused_transpose_split_quant( + x, input_scales, tokens_per_expert, pow_2_scales=False +): """ Applies fused transpose, split, and quantization operation for Mixture of Experts (MoE) models. @@ -228,7 +230,7 @@ def fused_transpose_split_quant(x, tokens_per_expert, pow_2_scales=False): if in_dynamic_or_pir_mode(): return _C_ops.fused_transpose_split_quant( - x, tokens_per_expert, pow_2_scales + x, input_scales, tokens_per_expert, pow_2_scales ) diff --git a/test/legacy_test/test_fused_transpose_split_quant_op.py b/test/legacy_test/test_fused_transpose_split_quant_op.py index edfea14fc1f35d..6c8604ba2ea876 100644 --- a/test/legacy_test/test_fused_transpose_split_quant_op.py +++ b/test/legacy_test/test_fused_transpose_split_quant_op.py @@ -17,8 +17,20 @@ import paddle -def fused_transpose_split_quant_ref(x, tokens_per_expert, pow_2_scales): +def dequant_ref( + fp8_tensor: paddle.Tensor, scale: paddle.Tensor, block_size: int = 128 +) -> paddle.Tensor: + """Helper function to dequantize fp8 tensor to bf16""" + expanded_scale = paddle.repeat_interleave(scale, repeats=128, axis=-1) + # Handle non-aligned cases by truncating + expanded_scale = expanded_scale[:, : fp8_tensor.shape[-1]] + return (fp8_tensor.astype('float32') * expanded_scale).astype('bfloat16') + + +def fused_transpose_split_quant_ref(x, xscale, tokens_per_expert, pow_2_scales): shape = x.shape + if x.dtype == paddle.float8_e4m3fn: + x = dequant_ref(x, xscale) x = x.reshape([shape[0] // 128, 128, shape[1]]) amax = x.astype('float32').abs().max(axis=1) @@ -37,43 +49,76 @@ def fused_transpose_split_quant_ref(x, tokens_per_expert, pow_2_scales): return out, scale -def test_fused_transpose_split_quant(tokens_per_expert, seq_len, pow_2_scales): +def test_fused_transpose_split_quant( + tokens_per_expert, seq_len, pow_2_scales, using_fp8=False +): x = paddle.randn([sum(tokens_per_expert), seq_len], dtype='bfloat16') - x = paddle.clip(x, min=-50, max=50) + if using_fp8: + x = x.cast('float8_e4m3fn') + xscale = ( + paddle.randn( + [sum(tokens_per_expert), (seq_len + 127) // 128], dtype='float32' + ) + if using_fp8 + else None + ) + # x = paddle.clip(x, min=-50, max=50) out, scale = paddle.incubate.nn.functional.fused_transpose_split_quant( - x, tokens_per_expert, pow_2_scales + x, xscale, tokens_per_expert, pow_2_scales ) out_ref, scale_ref = fused_transpose_split_quant_ref( - x, tokens_per_expert, pow_2_scales + x, xscale, tokens_per_expert, pow_2_scales ) for t, t_ref in zip(out, out_ref): - np.testing.assert_allclose(t.astype('float32'), t_ref.astype('float32')) + try: + np.testing.assert_allclose( + t.astype('float32'), t_ref.astype('float32') + ) + except AssertionError as e: + print("AssertionError", e) for t, t_ref in zip(scale, scale_ref): - np.testing.assert_allclose(t, t_ref) + try: + np.testing.assert_allclose(t, t_ref) + except AssertionError as e: + print("AssertionError", e) def run(): - test_fused_transpose_split_quant([0, 0], 1024, False) - test_fused_transpose_split_quant([128, 2 * 128], 0, True) - test_fused_transpose_split_quant([128], 1, False) - test_fused_transpose_split_quant([0, 128, 0, 2 * 128], 127, True) - test_fused_transpose_split_quant([3 * 128, 4 * 128, 5 * 128], 233, False) - test_fused_transpose_split_quant( - [24 * 128, 128, 50 * 128, 16 * 128], 2162, True - ) - test_fused_transpose_split_quant( - [7 * 128, 29 * 128, 3 * 128, 128 * 128, 13 * 128], 4000, False - ) - test_fused_transpose_split_quant( - [18 * 128, 5 * 128, 24 * 128, 128, 6 * 128, 0, 27 * 128, 7 * 128], - 7168, - True, - ) + fp8_choice = [True, False] + for using_fp8 in fp8_choice: + test_fused_transpose_split_quant( + [0, 0], 1024, False, using_fp8=using_fp8 + ) + test_fused_transpose_split_quant( + [128, 2 * 128], 0, True, using_fp8=using_fp8 + ) + test_fused_transpose_split_quant([128], 1, False, using_fp8=using_fp8) + test_fused_transpose_split_quant( + [0, 128, 0, 2 * 128], 127, True, using_fp8=using_fp8 + ) + test_fused_transpose_split_quant( + [3 * 128, 4 * 128, 5 * 128], 233, False, using_fp8=using_fp8 + ) + test_fused_transpose_split_quant( + [24 * 128, 128, 50 * 128, 16 * 128], 2162, True, using_fp8=using_fp8 + ) + test_fused_transpose_split_quant( + [7 * 128, 29 * 128, 3 * 128, 128 * 128, 13 * 128], + 4000, + False, + using_fp8=using_fp8, + ) + test_fused_transpose_split_quant( + [18 * 128, 5 * 128, 24 * 128, 128, 6 * 128, 0, 27 * 128, 7 * 128], + 7168, + True, + using_fp8=using_fp8, + ) if __name__ == '__main__': From 4c05c63052889fbc4d832497488abf9f871bc9df Mon Sep 17 00:00:00 2001 From: Pan Zhaowu Date: Fri, 8 Aug 2025 06:30:07 +0000 Subject: [PATCH 2/5] optimize performance. --- .../gpu/fused_transpose_split_quant_kernel.cu | 99 +++++++++---------- 1 file changed, 49 insertions(+), 50 deletions(-) diff --git a/paddle/phi/kernels/fusion/gpu/fused_transpose_split_quant_kernel.cu b/paddle/phi/kernels/fusion/gpu/fused_transpose_split_quant_kernel.cu index ce9611bf278426..091c7241d4392f 100644 --- a/paddle/phi/kernels/fusion/gpu/fused_transpose_split_quant_kernel.cu +++ b/paddle/phi/kernels/fusion/gpu/fused_transpose_split_quant_kernel.cu @@ -33,50 +33,58 @@ template __device__ void BlockLoad(const InT* input, const float* input_scales, __nv_bfloat16 x[8][4], - size_t K) { + size_t K, + size_t k_scaled) { constexpr bool need_dequant = std::is_same_v; + +#pragma unroll for (uint32_t i = 0; i < 8; i++) { - size_t local_off_M = threadIdx.y + i * 16; - size_t off_m = blockIdx.x * size_t(128) + local_off_M; - size_t off_k = blockIdx.y * 128 + threadIdx.x * VecSize; - size_t offset = off_m * K + off_k; + const uint32_t local_off_M = threadIdx.y + i * 16; + const uint32_t off_m = blockIdx.x * 128 + local_off_M; + const uint32_t off_k = blockIdx.y * 128 + threadIdx.x * VecSize; + const size_t offset = off_m * K + off_k; + float scale; if constexpr (need_dequant) { - scale = input_scales[local_off_M]; + const uint32_t m_base = blockIdx.x * 128; + const uint32_t m_stride = k_scaled; + scale = input_scales[off_m * m_stride + blockIdx.y]; } +#pragma unroll for (uint32_t j = 0; j < 4; j += VecSize) { - if (off_k + j * 32 < K) { - size_t idx = offset + j * 32; - using LoadT = VecType; - LoadT data = *reinterpret_cast(input + idx); - for (uint32_t k = 0; k < VecSize; k++) { - if constexpr (need_dequant) { - x[i][j + k] = __float2bfloat16(static_cast(data[k]) * scale); - } else { - x[i][j + k] = (*reinterpret_cast<__nv_bfloat16*>(&data[k])); - } + const size_t idx = offset + j * 32; + using LoadT = VecType; + LoadT data = *reinterpret_cast(input + idx); +#pragma unroll + for (uint32_t k = 0; k < VecSize; k++) { + if constexpr (need_dequant) { + x[i][j + k] = __float2bfloat16(static_cast(data[k]) * scale); + } else { + x[i][j + k] = (*reinterpret_cast<__nv_bfloat16*>(&data[k])); } } } } } - template __device__ void BlockColumnScale(const __nv_bfloat16 x[8][4], float scales[128], __nv_bfloat16* shm) { // reduce [(8), 16, 32, 4] => [16, 32, 4] __nv_bfloat16 warp_max[4]; +#pragma unroll for (uint32_t i = 0; i < 8; i++) { +#pragma unroll for (uint32_t j = 0; j < 4; j++) { - __nv_bfloat16 t = BF16_ABS(x[i][j]); + const __nv_bfloat16 t = BF16_ABS(x[i][j]); warp_max[j] = i == 0 ? t : BF16_MAX(warp_max[j], t); } } // reduce [(16), 32, 4] => [8, 32, 4] if (threadIdx.y >= 8) { +#pragma unroll for (uint32_t j = 0; j < 4; j++) { shm[(threadIdx.y - 8) * 128 + threadIdx.x + j * 32] = warp_max[j]; } @@ -86,8 +94,9 @@ __device__ void BlockColumnScale(const __nv_bfloat16 x[8][4], // reduce [(8), 32, 4] => [32, 4] for (uint32_t offset = 8; offset > 0; offset /= 2) { if (threadIdx.y < offset) { +#pragma unroll for (uint32_t j = 0; j < 4; j++) { - __nv_bfloat16 other = + const __nv_bfloat16 other = offset == 8 ? warp_max[j] : shm[(threadIdx.y + offset) * 128 + threadIdx.x + j * 32]; @@ -119,9 +128,9 @@ __device__ void BlockStoreScale(float* scale, off = (off / 64) * 64 + (off % 2) * 32 + (off % 64) / 2; } float scale_out = 1.0f / scales[off]; - size_t idx_y = blockIdx.x - off_m / 128; - size_t idx_x = blockIdx.y * 128 + threadIdx.y * 32 + threadIdx.x; - size_t idx = idx_y * K + idx_x; + const size_t idx_y = blockIdx.x - off_m / 128; + const size_t idx_x = blockIdx.y * 128 + threadIdx.y * 32 + threadIdx.x; + const size_t idx = idx_y * K + idx_x; if (idx_x < K) { scale[idx] = scale_out; } @@ -134,14 +143,16 @@ __device__ void BlockStoreOut(OutT* out, size_t cur_tokens, const OutT shm[128][129], size_t K) { +#pragma unroll for (uint32_t i = 0; i < 8; i++) { - size_t idx_m = blockIdx.x * size_t(128) + threadIdx.x * 4; - size_t idx_k = blockIdx.y * 128 + threadIdx.y + i * 16; - size_t idx = idx_k * cur_tokens + (idx_m - off_m); + const size_t idx_m = blockIdx.x * size_t(128) + threadIdx.x * 4; + const size_t idx_k = blockIdx.y * 128 + threadIdx.y + i * 16; + const size_t idx = idx_k * cur_tokens + (idx_m - off_m); if (idx_k < K) { using StoreT = VecType; StoreT data; +#pragma unroll for (uint32_t j = 0; j < VecSize; j++) { data[j] = shm[i * 16 + threadIdx.y][threadIdx.x * 4 + j]; } @@ -152,35 +163,24 @@ __device__ void BlockStoreOut(OutT* out, template __global__ void __launch_bounds__(512) - FusedTransposeSplitQuantKernel(const InT* __restrict__ input, - const float* __restrict__ input_scales, - int64_t* __restrict__ meta, - size_t num_experts, - size_t K, - size_t k_scaled) { + /* */ FusedTransposeSplitQuantKernel( + const InT* __restrict__ input, + const float* __restrict__ input_scales, + int64_t* __restrict__ meta, + size_t num_experts, + size_t K, + size_t k_scaled) { __shared__ OutT shm[128][129]; - __shared__ float scales[128]; __shared__ size_t expert_info[2]; + __shared__ float scales[128]; // 用于存储列方向计算出的scale - // 0. Load input_scales if float8 - if constexpr (std::is_same_v) { - const int tid = blockDim.y * threadIdx.x + threadIdx.y; - if (tid < 128) { - const size_t m_base = blockIdx.x * size_t(128); - const size_t k_base = blockIdx.y * size_t(128); - const size_t m_stride = k_scaled; - const size_t input_scale_offset = (m_base + tid) * m_stride + blockIdx.y; - scales[tid] = input_scales[input_scale_offset]; - } - __syncthreads(); - } int64_t* tokens_per_expert = meta; OutT** out_ptrs = reinterpret_cast(meta + num_experts); float** scale_ptrs = reinterpret_cast(meta + num_experts * 2); // 1. Load 128x128 elements from input __nv_bfloat16 x[8][4]; - BlockLoad(input, scales, x, K); + BlockLoad(input, input_scales, x, K, k_scaled); // 2. Get expert index and offset of the current block if (threadIdx.x == 0 && threadIdx.y == 0) { @@ -198,10 +198,6 @@ __global__ void __launch_bounds__(512) expert_info[1] = off_m; } - if constexpr (std::is_same_v) { - __syncthreads(); - } - // 3. Calculate scale along the column BlockColumnScale( x, scales, reinterpret_cast<__nv_bfloat16*>(shm)); @@ -211,9 +207,12 @@ __global__ void __launch_bounds__(512) const size_t off_m = expert_info[1]; BlockStoreScale(scale_ptrs[expert_idx], off_m, scales, K); - // 5. Scale x and save into shared memory with transposed layout +// 5. Scale x and save into shared memory with transposed layout +#pragma unroll for (uint32_t i = 0; i < 8; i++) { +#pragma unroll for (uint32_t j = 0; j < 4; j += VecSize) { +#pragma unroll for (uint32_t k = 0; k < VecSize; k++) { float x_fp32 = static_cast(x[i][j + k]); float x_scaled = x_fp32 * scales[threadIdx.x + (j + k) * 32]; From 3f52f1fbee58e3cf2e47331b63c2bc5cfebf47d0 Mon Sep 17 00:00:00 2001 From: Pan Zhaowu Date: Fri, 8 Aug 2025 06:44:10 +0000 Subject: [PATCH 3/5] Clean comment --- .../kernels/fusion/gpu/fused_transpose_split_quant_kernel.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/phi/kernels/fusion/gpu/fused_transpose_split_quant_kernel.cu b/paddle/phi/kernels/fusion/gpu/fused_transpose_split_quant_kernel.cu index 091c7241d4392f..99beb61553a101 100644 --- a/paddle/phi/kernels/fusion/gpu/fused_transpose_split_quant_kernel.cu +++ b/paddle/phi/kernels/fusion/gpu/fused_transpose_split_quant_kernel.cu @@ -172,7 +172,7 @@ __global__ void __launch_bounds__(512) size_t k_scaled) { __shared__ OutT shm[128][129]; __shared__ size_t expert_info[2]; - __shared__ float scales[128]; // 用于存储列方向计算出的scale + __shared__ float scales[128]; // May be reused? Is it worthy? int64_t* tokens_per_expert = meta; OutT** out_ptrs = reinterpret_cast(meta + num_experts); From f8099f328fd98f12b6049b76daa1b5fa1d39968d Mon Sep 17 00:00:00 2001 From: Pan Zhaowu Date: Fri, 8 Aug 2025 06:45:35 +0000 Subject: [PATCH 4/5] clean miscs --- .../gpu/fused_transpose_split_quant_kernel.cu | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/paddle/phi/kernels/fusion/gpu/fused_transpose_split_quant_kernel.cu b/paddle/phi/kernels/fusion/gpu/fused_transpose_split_quant_kernel.cu index 99beb61553a101..16503aa32f263d 100644 --- a/paddle/phi/kernels/fusion/gpu/fused_transpose_split_quant_kernel.cu +++ b/paddle/phi/kernels/fusion/gpu/fused_transpose_split_quant_kernel.cu @@ -163,13 +163,12 @@ __device__ void BlockStoreOut(OutT* out, template __global__ void __launch_bounds__(512) - /* */ FusedTransposeSplitQuantKernel( - const InT* __restrict__ input, - const float* __restrict__ input_scales, - int64_t* __restrict__ meta, - size_t num_experts, - size_t K, - size_t k_scaled) { + FusedTransposeSplitQuantKernel(const InT* __restrict__ input, + const float* __restrict__ input_scales, + int64_t* __restrict__ meta, + size_t num_experts, + size_t K, + size_t k_scaled) { __shared__ OutT shm[128][129]; __shared__ size_t expert_info[2]; __shared__ float scales[128]; // May be reused? Is it worthy? From 5165e16abebead49528e49606419c9fa38523a56 Mon Sep 17 00:00:00 2001 From: Pan Zhaowu Date: Mon, 11 Aug 2025 03:25:01 +0000 Subject: [PATCH 5/5] Fix example --- python/paddle/incubate/nn/functional/fp8.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/incubate/nn/functional/fp8.py b/python/paddle/incubate/nn/functional/fp8.py index c430f5ee6076d2..be61e7bdb72ae3 100644 --- a/python/paddle/incubate/nn/functional/fp8.py +++ b/python/paddle/incubate/nn/functional/fp8.py @@ -217,7 +217,7 @@ def fused_transpose_split_quant( >>> x = paddle.randn([384, 512], dtype='bfloat16') >>> x = paddle.clip(x, min=-50, max=50) >>> tokens_per_expert = [128, 128, 128] - >>> outs, scales = F.fused_transpose_split_quant(x, tokens_per_expert, pow_2_scales=True) + >>> outs, scales = F.fused_transpose_split_quant(x,None, tokens_per_expert, pow_2_scales=True) >>> print(outs[0].shape) [512, 128] >>> print(scales[0].shape)