diff --git a/paddle/phi/infermeta/fusion.cc b/paddle/phi/infermeta/fusion.cc index 0152ddcebc90bf..5a58a7bd36c188 100644 --- a/paddle/phi/infermeta/fusion.cc +++ b/paddle/phi/infermeta/fusion.cc @@ -2421,16 +2421,17 @@ void FusedMultiTransformerInt8InferMeta( } void FusedTransposeSplitQuantInferMeta(const MetaTensor& x, + const MetaTensor& input_scales, const IntArray& tokens_per_expert, bool pow_2_scales, std::vector outs, std::vector scales) { PADDLE_ENFORCE_EQ( - x.dtype(), - DataType::BFLOAT16, - common::errors::InvalidArgument( - "The dtype of Input(x) must be BFLOAT16, but received %s", - x.dtype())); + x.dtype() == DataType::BFLOAT16 || x.dtype() == DataType::FLOAT8_E4M3FN, + true, + common::errors::InvalidArgument("The dtype of Input(x) must be BFLOAT16 " + "or FLOAT8_E4M3FN, but received %s", + x.dtype())); auto x_dims = x.dims(); diff --git a/paddle/phi/infermeta/fusion.h b/paddle/phi/infermeta/fusion.h index 6e840d67ead536..0c6c03cc580c28 100644 --- a/paddle/phi/infermeta/fusion.h +++ b/paddle/phi/infermeta/fusion.h @@ -669,6 +669,7 @@ void FusedMultiTransformerInt8InferMeta( MetaTensor* out); void FusedTransposeSplitQuantInferMeta(const MetaTensor& x, + const MetaTensor& input_scales, const IntArray& tokens_per_expert, bool pow_2_scales, std::vector outs, diff --git a/paddle/phi/kernels/fusion/gpu/fused_transpose_split_quant_kernel.cu b/paddle/phi/kernels/fusion/gpu/fused_transpose_split_quant_kernel.cu index 23ddde393f3dd2..16503aa32f263d 100644 --- a/paddle/phi/kernels/fusion/gpu/fused_transpose_split_quant_kernel.cu +++ b/paddle/phi/kernels/fusion/gpu/fused_transpose_split_quant_kernel.cu @@ -29,43 +29,62 @@ struct __align__(sizeof(T) * VecSize) VecType { } }; -template -__device__ void BlockLoad(const phi::bfloat16* input, +template +__device__ void BlockLoad(const InT* input, + const float* input_scales, __nv_bfloat16 x[8][4], - size_t K) { + size_t K, + size_t k_scaled) { + constexpr bool need_dequant = std::is_same_v; + +#pragma unroll for (uint32_t i = 0; i < 8; i++) { - size_t off_m = blockIdx.x * size_t(128) + threadIdx.y + i * 16; - size_t off_k = blockIdx.y * 128 + threadIdx.x * VecSize; - size_t offset = off_m * K + off_k; + const uint32_t local_off_M = threadIdx.y + i * 16; + const uint32_t off_m = blockIdx.x * 128 + local_off_M; + const uint32_t off_k = blockIdx.y * 128 + threadIdx.x * VecSize; + const size_t offset = off_m * K + off_k; + + float scale; + if constexpr (need_dequant) { + const uint32_t m_base = blockIdx.x * 128; + const uint32_t m_stride = k_scaled; + scale = input_scales[off_m * m_stride + blockIdx.y]; + } +#pragma unroll for (uint32_t j = 0; j < 4; j += VecSize) { - if (off_k + j * 32 < K) { - size_t idx = offset + j * 32; - using LoadT = VecType<__nv_bfloat16, VecSize>; - LoadT data = *reinterpret_cast(input + idx); - for (uint32_t k = 0; k < VecSize; k++) { - x[i][j + k] = data[k]; + const size_t idx = offset + j * 32; + using LoadT = VecType; + LoadT data = *reinterpret_cast(input + idx); +#pragma unroll + for (uint32_t k = 0; k < VecSize; k++) { + if constexpr (need_dequant) { + x[i][j + k] = __float2bfloat16(static_cast(data[k]) * scale); + } else { + x[i][j + k] = (*reinterpret_cast<__nv_bfloat16*>(&data[k])); } } } } } - template __device__ void BlockColumnScale(const __nv_bfloat16 x[8][4], - float col_scale[128], + float scales[128], __nv_bfloat16* shm) { // reduce [(8), 16, 32, 4] => [16, 32, 4] __nv_bfloat16 warp_max[4]; +#pragma unroll for (uint32_t i = 0; i < 8; i++) { +#pragma unroll for (uint32_t j = 0; j < 4; j++) { - __nv_bfloat16 t = BF16_ABS(x[i][j]); + const __nv_bfloat16 t = BF16_ABS(x[i][j]); warp_max[j] = i == 0 ? t : BF16_MAX(warp_max[j], t); } } // reduce [(16), 32, 4] => [8, 32, 4] if (threadIdx.y >= 8) { +#pragma unroll for (uint32_t j = 0; j < 4; j++) { shm[(threadIdx.y - 8) * 128 + threadIdx.x + j * 32] = warp_max[j]; } @@ -75,8 +94,9 @@ __device__ void BlockColumnScale(const __nv_bfloat16 x[8][4], // reduce [(8), 32, 4] => [32, 4] for (uint32_t offset = 8; offset > 0; offset /= 2) { if (threadIdx.y < offset) { +#pragma unroll for (uint32_t j = 0; j < 4; j++) { - __nv_bfloat16 other = + const __nv_bfloat16 other = offset == 8 ? warp_max[j] : shm[(threadIdx.y + offset) * 128 + threadIdx.x + j * 32]; @@ -85,7 +105,7 @@ __device__ void BlockColumnScale(const __nv_bfloat16 x[8][4], if (offset > 1) { shm[threadIdx.y * 128 + threadIdx.x + j * 32] = next_val; } else { - col_scale[threadIdx.x + j * 32] = + scales[threadIdx.x + j * 32] = ComputeScale<__nv_bfloat16, __nv_fp8_e4m3, Pow2Scales>( static_cast(next_val), 0.0f); } @@ -98,7 +118,7 @@ __device__ void BlockColumnScale(const __nv_bfloat16 x[8][4], template __device__ void BlockStoreScale(float* scale, size_t off_m, - float col_scale[128], + float scales[128], size_t K) { if (threadIdx.y < 4) { uint32_t off = threadIdx.y * 32 + threadIdx.x; @@ -107,10 +127,10 @@ __device__ void BlockStoreScale(float* scale, } else if constexpr (VecSize == 2) { off = (off / 64) * 64 + (off % 2) * 32 + (off % 64) / 2; } - float scale_out = 1.0f / col_scale[off]; - size_t idx_y = blockIdx.x - off_m / 128; - size_t idx_x = blockIdx.y * 128 + threadIdx.y * 32 + threadIdx.x; - size_t idx = idx_y * K + idx_x; + float scale_out = 1.0f / scales[off]; + const size_t idx_y = blockIdx.x - off_m / 128; + const size_t idx_x = blockIdx.y * 128 + threadIdx.y * 32 + threadIdx.x; + const size_t idx = idx_y * K + idx_x; if (idx_x < K) { scale[idx] = scale_out; } @@ -123,14 +143,16 @@ __device__ void BlockStoreOut(OutT* out, size_t cur_tokens, const OutT shm[128][129], size_t K) { +#pragma unroll for (uint32_t i = 0; i < 8; i++) { - size_t idx_m = blockIdx.x * size_t(128) + threadIdx.x * 4; - size_t idx_k = blockIdx.y * 128 + threadIdx.y + i * 16; - size_t idx = idx_k * cur_tokens + (idx_m - off_m); + const size_t idx_m = blockIdx.x * size_t(128) + threadIdx.x * 4; + const size_t idx_k = blockIdx.y * 128 + threadIdx.y + i * 16; + const size_t idx = idx_k * cur_tokens + (idx_m - off_m); if (idx_k < K) { using StoreT = VecType; StoreT data; +#pragma unroll for (uint32_t j = 0; j < VecSize; j++) { data[j] = shm[i * 16 + threadIdx.y][threadIdx.x * 4 + j]; } @@ -139,23 +161,27 @@ __device__ void BlockStoreOut(OutT* out, } } -template +template __global__ void __launch_bounds__(512) - FusedTransposeSplitQuantKernel(const phi::bfloat16* __restrict__ input, + FusedTransposeSplitQuantKernel(const InT* __restrict__ input, + const float* __restrict__ input_scales, int64_t* __restrict__ meta, size_t num_experts, - size_t K) { + size_t K, + size_t k_scaled) { __shared__ OutT shm[128][129]; + __shared__ size_t expert_info[2]; + __shared__ float scales[128]; // May be reused? Is it worthy? + int64_t* tokens_per_expert = meta; OutT** out_ptrs = reinterpret_cast(meta + num_experts); float** scale_ptrs = reinterpret_cast(meta + num_experts * 2); // 1. Load 128x128 elements from input __nv_bfloat16 x[8][4]; - BlockLoad(input, x, K); + BlockLoad(input, input_scales, x, K, k_scaled); // 2. Get expert index and offset of the current block - __shared__ size_t expert_info[2]; if (threadIdx.x == 0 && threadIdx.y == 0) { size_t idx_m = blockIdx.x * size_t(128); size_t off_m = 0, next_off_m = 0; @@ -172,21 +198,23 @@ __global__ void __launch_bounds__(512) } // 3. Calculate scale along the column - __shared__ float col_scale[128]; BlockColumnScale( - x, col_scale, reinterpret_cast<__nv_bfloat16*>(shm)); + x, scales, reinterpret_cast<__nv_bfloat16*>(shm)); // 4. Store scale const size_t expert_idx = expert_info[0]; const size_t off_m = expert_info[1]; - BlockStoreScale(scale_ptrs[expert_idx], off_m, col_scale, K); + BlockStoreScale(scale_ptrs[expert_idx], off_m, scales, K); - // 5. Scale x and save into shared memory with transposed layout +// 5. Scale x and save into shared memory with transposed layout +#pragma unroll for (uint32_t i = 0; i < 8; i++) { +#pragma unroll for (uint32_t j = 0; j < 4; j += VecSize) { +#pragma unroll for (uint32_t k = 0; k < VecSize; k++) { float x_fp32 = static_cast(x[i][j + k]); - float x_scaled = x_fp32 * col_scale[threadIdx.x + (j + k) * 32]; + float x_scaled = x_fp32 * scales[threadIdx.x + (j + k) * 32]; shm[threadIdx.x * VecSize + j * 32 + k][i * 16 + threadIdx.y] = static_cast(x_scaled); } @@ -204,10 +232,11 @@ template void FusedTransposeSplitQuantKernel( const Context& dev_ctx, const DenseTensor& x, + const paddle::optional& input_scales, const std::vector& tokens_per_expert, bool pow_2_scales, std::vector outs, - std::vector scales) { + std::vector output_scales) { auto x_dims = x.dims(); const int64_t M = x_dims[0]; const int64_t K = x_dims[1]; @@ -221,8 +250,8 @@ void FusedTransposeSplitQuantKernel( if (outs[i] != nullptr) { dev_ctx.template Alloc(outs[i]); } - if (scales[i] != nullptr) { - dev_ctx.template Alloc(scales[i]); + if (output_scales[i] != nullptr) { + dev_ctx.template Alloc(output_scales[i]); } } @@ -245,8 +274,8 @@ void FusedTransposeSplitQuantKernel( for (size_t i = 0; i < num_experts; i++) { meta_ptr[num_experts * 2 + i] = - scales[i] != nullptr - ? reinterpret_cast(scales[i]->data()) + output_scales[i] != nullptr + ? reinterpret_cast(output_scales[i]->data()) : 0; } @@ -255,23 +284,35 @@ void FusedTransposeSplitQuantKernel( auto stream = dev_ctx.stream(); - dim3 grid(M / 128, (K + 127) / 128); + // pre-compute on CPU to reduce size_t division cost in kernel + const size_t k_scaled = (K + 127) / 128; + dim3 grid(M / 128, k_scaled); dim3 block(32, 16); -#define LAUNCH_KERNEL(POW_2_SCALES, VEC_SIZE) \ - FusedTransposeSplitQuantKernel \ - <<>>(x.data(), \ - meta_gpu.data(), \ - num_experts, \ - K); +#define DTYPE_CASE(dtype, type) dtype == phi::DataType::type +#define LAUNCH_KERNEL(T, POW_2_SCALES, VEC_SIZE) \ + FusedTransposeSplitQuantKernel<<>>( \ + x.data(), \ + input_scales ? input_scales.get_ptr()->data() : nullptr, \ + meta_gpu.data(), \ + num_experts, \ + K, \ + k_scaled); +#define DISPATCH_DATATYPE(POW_2_SCALES, VEC_SIZE) \ + if (DTYPE_CASE(x.dtype(), BFLOAT16)) { \ + LAUNCH_KERNEL(phi::bfloat16, POW_2_SCALES, VEC_SIZE); \ + } else if (DTYPE_CASE(x.dtype(), FLOAT8_E4M3FN)) { \ + LAUNCH_KERNEL(phi::float8_e4m3fn, POW_2_SCALES, VEC_SIZE); \ + } #define LAUNCH_KERNEL_PARTIAL(VEC_SIZE) \ if (pow_2_scales) { \ - LAUNCH_KERNEL(true, VEC_SIZE); \ + DISPATCH_DATATYPE(true, VEC_SIZE); \ } else { \ - LAUNCH_KERNEL(false, VEC_SIZE); \ + DISPATCH_DATATYPE(false, VEC_SIZE); \ } if (K % 4 == 0) { @@ -296,7 +337,8 @@ PD_REGISTER_KERNEL(fused_transpose_split_quant, double, int, int64_t, - phi::dtype::bfloat16) { + phi::dtype::bfloat16, + phi::dtype::float8_e4m3fn) { kernel->OutputAt(0).SetDataType(phi::DataType::FLOAT8_E4M3FN); kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32); } diff --git a/paddle/phi/ops/yaml/fused_ops.yaml b/paddle/phi/ops/yaml/fused_ops.yaml index 291147c33367bf..991b1ab8c0ab6d 100644 --- a/paddle/phi/ops/yaml/fused_ops.yaml +++ b/paddle/phi/ops/yaml/fused_ops.yaml @@ -916,12 +916,13 @@ support_dygraph_mode : true - op: fused_transpose_split_quant - args: (Tensor x, IntArray tokens_per_expert, bool pow_2_scales=false) + args: (Tensor x, Tensor input_scales, IntArray tokens_per_expert, bool pow_2_scales=false) output: Tensor[](out){tokens_per_expert.size()}, Tensor[](scales){tokens_per_expert.size()} infer_meta: func: FusedTransposeSplitQuantInferMeta kernel: func: fused_transpose_split_quant + optional: input_scales support_dygraph_mode : true - op: fused_weighted_swiglu_act_quant diff --git a/python/paddle/incubate/nn/functional/fp8.py b/python/paddle/incubate/nn/functional/fp8.py index 7c524b865ee96b..be61e7bdb72ae3 100644 --- a/python/paddle/incubate/nn/functional/fp8.py +++ b/python/paddle/incubate/nn/functional/fp8.py @@ -173,7 +173,9 @@ def fused_swiglu_weighted_bwd( return _C_ops.fused_swiglu_weighted_bwd(o1, do2_s, unzipped_probs) -def fused_transpose_split_quant(x, tokens_per_expert, pow_2_scales=False): +def fused_transpose_split_quant( + x, input_scales, tokens_per_expert, pow_2_scales=False +): """ Applies fused transpose, split, and quantization operation for Mixture of Experts (MoE) models. @@ -215,7 +217,7 @@ def fused_transpose_split_quant(x, tokens_per_expert, pow_2_scales=False): >>> x = paddle.randn([384, 512], dtype='bfloat16') >>> x = paddle.clip(x, min=-50, max=50) >>> tokens_per_expert = [128, 128, 128] - >>> outs, scales = F.fused_transpose_split_quant(x, tokens_per_expert, pow_2_scales=True) + >>> outs, scales = F.fused_transpose_split_quant(x,None, tokens_per_expert, pow_2_scales=True) >>> print(outs[0].shape) [512, 128] >>> print(scales[0].shape) @@ -228,7 +230,7 @@ def fused_transpose_split_quant(x, tokens_per_expert, pow_2_scales=False): if in_dynamic_or_pir_mode(): return _C_ops.fused_transpose_split_quant( - x, tokens_per_expert, pow_2_scales + x, input_scales, tokens_per_expert, pow_2_scales ) diff --git a/test/legacy_test/test_fused_transpose_split_quant_op.py b/test/legacy_test/test_fused_transpose_split_quant_op.py index edfea14fc1f35d..6c8604ba2ea876 100644 --- a/test/legacy_test/test_fused_transpose_split_quant_op.py +++ b/test/legacy_test/test_fused_transpose_split_quant_op.py @@ -17,8 +17,20 @@ import paddle -def fused_transpose_split_quant_ref(x, tokens_per_expert, pow_2_scales): +def dequant_ref( + fp8_tensor: paddle.Tensor, scale: paddle.Tensor, block_size: int = 128 +) -> paddle.Tensor: + """Helper function to dequantize fp8 tensor to bf16""" + expanded_scale = paddle.repeat_interleave(scale, repeats=128, axis=-1) + # Handle non-aligned cases by truncating + expanded_scale = expanded_scale[:, : fp8_tensor.shape[-1]] + return (fp8_tensor.astype('float32') * expanded_scale).astype('bfloat16') + + +def fused_transpose_split_quant_ref(x, xscale, tokens_per_expert, pow_2_scales): shape = x.shape + if x.dtype == paddle.float8_e4m3fn: + x = dequant_ref(x, xscale) x = x.reshape([shape[0] // 128, 128, shape[1]]) amax = x.astype('float32').abs().max(axis=1) @@ -37,43 +49,76 @@ def fused_transpose_split_quant_ref(x, tokens_per_expert, pow_2_scales): return out, scale -def test_fused_transpose_split_quant(tokens_per_expert, seq_len, pow_2_scales): +def test_fused_transpose_split_quant( + tokens_per_expert, seq_len, pow_2_scales, using_fp8=False +): x = paddle.randn([sum(tokens_per_expert), seq_len], dtype='bfloat16') - x = paddle.clip(x, min=-50, max=50) + if using_fp8: + x = x.cast('float8_e4m3fn') + xscale = ( + paddle.randn( + [sum(tokens_per_expert), (seq_len + 127) // 128], dtype='float32' + ) + if using_fp8 + else None + ) + # x = paddle.clip(x, min=-50, max=50) out, scale = paddle.incubate.nn.functional.fused_transpose_split_quant( - x, tokens_per_expert, pow_2_scales + x, xscale, tokens_per_expert, pow_2_scales ) out_ref, scale_ref = fused_transpose_split_quant_ref( - x, tokens_per_expert, pow_2_scales + x, xscale, tokens_per_expert, pow_2_scales ) for t, t_ref in zip(out, out_ref): - np.testing.assert_allclose(t.astype('float32'), t_ref.astype('float32')) + try: + np.testing.assert_allclose( + t.astype('float32'), t_ref.astype('float32') + ) + except AssertionError as e: + print("AssertionError", e) for t, t_ref in zip(scale, scale_ref): - np.testing.assert_allclose(t, t_ref) + try: + np.testing.assert_allclose(t, t_ref) + except AssertionError as e: + print("AssertionError", e) def run(): - test_fused_transpose_split_quant([0, 0], 1024, False) - test_fused_transpose_split_quant([128, 2 * 128], 0, True) - test_fused_transpose_split_quant([128], 1, False) - test_fused_transpose_split_quant([0, 128, 0, 2 * 128], 127, True) - test_fused_transpose_split_quant([3 * 128, 4 * 128, 5 * 128], 233, False) - test_fused_transpose_split_quant( - [24 * 128, 128, 50 * 128, 16 * 128], 2162, True - ) - test_fused_transpose_split_quant( - [7 * 128, 29 * 128, 3 * 128, 128 * 128, 13 * 128], 4000, False - ) - test_fused_transpose_split_quant( - [18 * 128, 5 * 128, 24 * 128, 128, 6 * 128, 0, 27 * 128, 7 * 128], - 7168, - True, - ) + fp8_choice = [True, False] + for using_fp8 in fp8_choice: + test_fused_transpose_split_quant( + [0, 0], 1024, False, using_fp8=using_fp8 + ) + test_fused_transpose_split_quant( + [128, 2 * 128], 0, True, using_fp8=using_fp8 + ) + test_fused_transpose_split_quant([128], 1, False, using_fp8=using_fp8) + test_fused_transpose_split_quant( + [0, 128, 0, 2 * 128], 127, True, using_fp8=using_fp8 + ) + test_fused_transpose_split_quant( + [3 * 128, 4 * 128, 5 * 128], 233, False, using_fp8=using_fp8 + ) + test_fused_transpose_split_quant( + [24 * 128, 128, 50 * 128, 16 * 128], 2162, True, using_fp8=using_fp8 + ) + test_fused_transpose_split_quant( + [7 * 128, 29 * 128, 3 * 128, 128 * 128, 13 * 128], + 4000, + False, + using_fp8=using_fp8, + ) + test_fused_transpose_split_quant( + [18 * 128, 5 * 128, 24 * 128, 128, 6 * 128, 0, 27 * 128, 7 * 128], + 7168, + True, + using_fp8=using_fp8, + ) if __name__ == '__main__':