11 changes: 6 additions & 5 deletions paddle/phi/infermeta/fusion.cc
@@ -2421,16 +2421,17 @@ void FusedMultiTransformerInt8InferMeta(
}

void FusedTransposeSplitQuantInferMeta(const MetaTensor& x,
const MetaTensor& input_scales,
const IntArray& tokens_per_expert,
bool pow_2_scales,
std::vector<MetaTensor*> outs,
std::vector<MetaTensor*> scales) {
PADDLE_ENFORCE_EQ(
x.dtype(),
DataType::BFLOAT16,
common::errors::InvalidArgument(
"The dtype of Input(x) must be BFLOAT16, but received %s",
x.dtype()));
x.dtype() == DataType::BFLOAT16 || x.dtype() == DataType::FLOAT8_E4M3FN,
true,
common::errors::InvalidArgument("The dtype of Input(x) must be BFLOAT16 "
"or FLOAT8_E4M3FN, but received %s",
x.dtype()));

auto x_dims = x.dims();

1 change: 1 addition & 0 deletions paddle/phi/infermeta/fusion.h
@@ -669,6 +669,7 @@ void FusedMultiTransformerInt8InferMeta(
MetaTensor* out);

void FusedTransposeSplitQuantInferMeta(const MetaTensor& x,
const MetaTensor& input_scales,
const IntArray& tokens_per_expert,
bool pow_2_scales,
std::vector<MetaTensor*> outs,
146 changes: 94 additions & 52 deletions paddle/phi/kernels/fusion/gpu/fused_transpose_split_quant_kernel.cu
@@ -29,43 +29,62 @@ struct __align__(sizeof(T) * VecSize) VecType {
}
};

template <int VecSize>
__device__ void BlockLoad(const phi::bfloat16* input,
template <typename InT, int VecSize>
__device__ void BlockLoad(const InT* input,
const float* input_scales,
__nv_bfloat16 x[8][4],
size_t K) {
size_t K,
size_t k_scaled) {
constexpr bool need_dequant = std::is_same_v<InT, phi::dtype::float8_e4m3fn>;

#pragma unroll
for (uint32_t i = 0; i < 8; i++) {
size_t off_m = blockIdx.x * size_t(128) + threadIdx.y + i * 16;
size_t off_k = blockIdx.y * 128 + threadIdx.x * VecSize;
size_t offset = off_m * K + off_k;
const uint32_t local_off_M = threadIdx.y + i * 16;
const uint32_t off_m = blockIdx.x * 128 + local_off_M;
const uint32_t off_k = blockIdx.y * 128 + threadIdx.x * VecSize;
const size_t offset = off_m * K + off_k;

float scale;
if constexpr (need_dequant) {
const uint32_t m_base = blockIdx.x * 128;
const uint32_t m_stride = k_scaled;
scale = input_scales[off_m * m_stride + blockIdx.y];
}

#pragma unroll
for (uint32_t j = 0; j < 4; j += VecSize) {
if (off_k + j * 32 < K) {
size_t idx = offset + j * 32;
using LoadT = VecType<__nv_bfloat16, VecSize>;
LoadT data = *reinterpret_cast<const LoadT*>(input + idx);
for (uint32_t k = 0; k < VecSize; k++) {
x[i][j + k] = data[k];
const size_t idx = offset + j * 32;
using LoadT = VecType<InT, VecSize>;
LoadT data = *reinterpret_cast<const LoadT*>(input + idx);
#pragma unroll
for (uint32_t k = 0; k < VecSize; k++) {
if constexpr (need_dequant) {
x[i][j + k] = __float2bfloat16(static_cast<float>(data[k]) * scale);
} else {
x[i][j + k] = (*reinterpret_cast<__nv_bfloat16*>(&data[k]));
}
}
}
}
}
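
For reference, when InT is FP8 the load above dequantizes on the fly: the element at row off_m and column block blockIdx.y is multiplied by input_scales[off_m * k_scaled + blockIdx.y], i.e. one float scale per row per 128-column block. A minimal NumPy sketch of that indexing, assuming input_scales is row-major with shape [M, ceil(K/128)] (the layout is inferred from the kernel and is not stated elsewhere in this PR):

import numpy as np

def dequant_fp8_rows(x_fp8, input_scales):
    # x_fp8:        [M, K] FP8 payload (modeled as float here)
    # input_scales: [M, ceil(K / 128)], one scale per row per 128-column block
    M, K = x_fp8.shape
    k_scaled = (K + 127) // 128
    x = np.empty((M, K), dtype=np.float32)
    for kb in range(k_scaled):
        cols = slice(kb * 128, min((kb + 1) * 128, K))
        x[:, cols] = x_fp8[:, cols].astype(np.float32) * input_scales[:, kb:kb + 1]
    return x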

template <bool Pow2Scales>
__device__ void BlockColumnScale(const __nv_bfloat16 x[8][4],
float col_scale[128],
float scales[128],
__nv_bfloat16* shm) {
// reduce [(8), 16, 32, 4] => [16, 32, 4]
__nv_bfloat16 warp_max[4];
#pragma unroll
for (uint32_t i = 0; i < 8; i++) {
#pragma unroll
for (uint32_t j = 0; j < 4; j++) {
__nv_bfloat16 t = BF16_ABS(x[i][j]);
const __nv_bfloat16 t = BF16_ABS(x[i][j]);
warp_max[j] = i == 0 ? t : BF16_MAX(warp_max[j], t);
}
}

// reduce [(16), 32, 4] => [8, 32, 4]
if (threadIdx.y >= 8) {
#pragma unroll
for (uint32_t j = 0; j < 4; j++) {
shm[(threadIdx.y - 8) * 128 + threadIdx.x + j * 32] = warp_max[j];
}
@@ -75,8 +94,9 @@ __device__ void BlockColumnScale(const __nv_bfloat16 x[8][4],
// reduce [(8), 32, 4] => [32, 4]
for (uint32_t offset = 8; offset > 0; offset /= 2) {
if (threadIdx.y < offset) {
#pragma unroll
for (uint32_t j = 0; j < 4; j++) {
__nv_bfloat16 other =
const __nv_bfloat16 other =
offset == 8
? warp_max[j]
: shm[(threadIdx.y + offset) * 128 + threadIdx.x + j * 32];
@@ -85,7 +105,7 @@ __device__ void BlockColumnScale(const __nv_bfloat16 x[8][4],
if (offset > 1) {
shm[threadIdx.y * 128 + threadIdx.x + j * 32] = next_val;
} else {
col_scale[threadIdx.x + j * 32] =
scales[threadIdx.x + j * 32] =
ComputeScale<__nv_bfloat16, __nv_fp8_e4m3, Pow2Scales>(
static_cast<float>(next_val), 0.0f);
}
@@ -98,7 +118,7 @@ __device__ void BlockColumnScale(const __nv_bfloat16 x[8][4],
template <typename OutT, int VecSize>
__device__ void BlockStoreScale(float* scale,
size_t off_m,
float col_scale[128],
float scales[128],
size_t K) {
if (threadIdx.y < 4) {
uint32_t off = threadIdx.y * 32 + threadIdx.x;
@@ -107,10 +127,10 @@ __device__ void BlockStoreScale(float* scale,
} else if constexpr (VecSize == 2) {
off = (off / 64) * 64 + (off % 2) * 32 + (off % 64) / 2;
}
float scale_out = 1.0f / col_scale[off];
size_t idx_y = blockIdx.x - off_m / 128;
size_t idx_x = blockIdx.y * 128 + threadIdx.y * 32 + threadIdx.x;
size_t idx = idx_y * K + idx_x;
float scale_out = 1.0f / scales[off];
const size_t idx_y = blockIdx.x - off_m / 128;
const size_t idx_x = blockIdx.y * 128 + threadIdx.y * 32 + threadIdx.x;
const size_t idx = idx_y * K + idx_x;
if (idx_x < K) {
scale[idx] = scale_out;
}
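
The value written here is the reciprocal of the per-column quantization multiplier, so it can be used directly as a dequantization scale. A rough illustration of how a consumer could recover one expert's values, assuming out_e has shape [K, tokens_e] and scale_e has shape [tokens_e / 128, K] as suggested by the indexing above (this helper is illustrative and not part of the PR):

import numpy as np

def dequant_expert(out_e, scale_e):
    # out_e:   [K, tokens_e]         transposed FP8 payload (modeled as float)
    # scale_e: [tokens_e / 128, K]   one scale per 128-token block per column
    K, tokens = out_e.shape
    x = np.empty((tokens, K), dtype=np.float32)
    for t in range(tokens):
        x[t, :] = out_e[:, t].astype(np.float32) * scale_e[t // 128, :]
    return x  # approximately the expert's original [tokens_e, K] slice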
@@ -123,14 +143,16 @@ __device__ void BlockStoreOut(OutT* out,
size_t cur_tokens,
const OutT shm[128][129],
size_t K) {
#pragma unroll
for (uint32_t i = 0; i < 8; i++) {
size_t idx_m = blockIdx.x * size_t(128) + threadIdx.x * 4;
size_t idx_k = blockIdx.y * 128 + threadIdx.y + i * 16;
size_t idx = idx_k * cur_tokens + (idx_m - off_m);
const size_t idx_m = blockIdx.x * size_t(128) + threadIdx.x * 4;
const size_t idx_k = blockIdx.y * 128 + threadIdx.y + i * 16;
const size_t idx = idx_k * cur_tokens + (idx_m - off_m);

if (idx_k < K) {
using StoreT = VecType<OutT, VecSize>;
StoreT data;
#pragma unroll
for (uint32_t j = 0; j < VecSize; j++) {
data[j] = shm[i * 16 + threadIdx.y][threadIdx.x * 4 + j];
}
@@ -139,23 +161,27 @@ __device__ void BlockStoreOut(OutT* out,
}
}

template <typename OutT, bool Pow2Scales, int VecSize>
template <typename InT, typename OutT, bool Pow2Scales, int VecSize>
__global__ void __launch_bounds__(512)
FusedTransposeSplitQuantKernel(const phi::bfloat16* __restrict__ input,
FusedTransposeSplitQuantKernel(const InT* __restrict__ input,
const float* __restrict__ input_scales,
int64_t* __restrict__ meta,
size_t num_experts,
size_t K) {
size_t K,
size_t k_scaled) {
__shared__ OutT shm[128][129];
__shared__ size_t expert_info[2];
__shared__ float scales[128]; // Could this buffer be reused? Is it worth it?

int64_t* tokens_per_expert = meta;
OutT** out_ptrs = reinterpret_cast<OutT**>(meta + num_experts);
float** scale_ptrs = reinterpret_cast<float**>(meta + num_experts * 2);

// 1. Load 128x128 elements from input
__nv_bfloat16 x[8][4];
BlockLoad<VecSize>(input, x, K);
BlockLoad<InT, VecSize>(input, input_scales, x, K, k_scaled);

// 2. Get expert index and offset of the current block
__shared__ size_t expert_info[2];
if (threadIdx.x == 0 && threadIdx.y == 0) {
size_t idx_m = blockIdx.x * size_t(128);
size_t off_m = 0, next_off_m = 0;
@@ -172,21 +198,23 @@ __global__ void __launch_bounds__(512)
}
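
The loop collapsed in this diff view walks tokens_per_expert to find which expert owns the current 128-row block and at which row offset that expert starts, then publishes the result to the whole block through expert_info. A hedged Python sketch of that mapping (conceptual only; the real code is the elided device loop):

def find_expert(block_row_start, tokens_per_expert):
    # block_row_start corresponds to blockIdx.x * 128
    off_m = 0
    for expert_idx, tokens in enumerate(tokens_per_expert):
        if block_row_start < off_m + tokens:
            return expert_idx, off_m   # -> expert_info[0], expert_info[1]
        off_m += tokens
    raise ValueError("block row is past the last expert")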

// 3. Calculate scale along the column
__shared__ float col_scale[128];
BlockColumnScale<Pow2Scales>(
x, col_scale, reinterpret_cast<__nv_bfloat16*>(shm));
x, scales, reinterpret_cast<__nv_bfloat16*>(shm));

// 4. Store scale
const size_t expert_idx = expert_info[0];
const size_t off_m = expert_info[1];
BlockStoreScale<OutT, VecSize>(scale_ptrs[expert_idx], off_m, col_scale, K);
BlockStoreScale<OutT, VecSize>(scale_ptrs[expert_idx], off_m, scales, K);

// 5. Scale x and save into shared memory with transposed layout
// 5. Scale x and save into shared memory with transposed layout
#pragma unroll
for (uint32_t i = 0; i < 8; i++) {
#pragma unroll
for (uint32_t j = 0; j < 4; j += VecSize) {
#pragma unroll
for (uint32_t k = 0; k < VecSize; k++) {
float x_fp32 = static_cast<float>(x[i][j + k]);
float x_scaled = x_fp32 * col_scale[threadIdx.x + (j + k) * 32];
float x_scaled = x_fp32 * scales[threadIdx.x + (j + k) * 32];
shm[threadIdx.x * VecSize + j * 32 + k][i * 16 + threadIdx.y] =
static_cast<OutT>(x_scaled);
}
@@ -204,10 +232,11 @@ template <typename T, typename Context>
void FusedTransposeSplitQuantKernel(
const Context& dev_ctx,
const DenseTensor& x,
const paddle::optional<DenseTensor>& input_scales,
const std::vector<int64_t>& tokens_per_expert,
bool pow_2_scales,
std::vector<DenseTensor*> outs,
std::vector<DenseTensor*> scales) {
std::vector<DenseTensor*> output_scales) {
auto x_dims = x.dims();
const int64_t M = x_dims[0];
const int64_t K = x_dims[1];
@@ -221,8 +250,8 @@ void FusedTransposeSplitQuantKernel(
if (outs[i] != nullptr) {
dev_ctx.template Alloc<phi::dtype::float8_e4m3fn>(outs[i]);
}
if (scales[i] != nullptr) {
dev_ctx.template Alloc<float>(scales[i]);
if (output_scales[i] != nullptr) {
dev_ctx.template Alloc<float>(output_scales[i]);
}
}

@@ -245,8 +274,8 @@

for (size_t i = 0; i < num_experts; i++) {
meta_ptr[num_experts * 2 + i] =
scales[i] != nullptr
? reinterpret_cast<int64_t>(scales[i]->data<float>())
output_scales[i] != nullptr
? reinterpret_cast<int64_t>(output_scales[i]->data<float>())
: 0;
}

@@ -255,23 +284,35 @@

auto stream = dev_ctx.stream();

dim3 grid(M / 128, (K + 127) / 128);
// pre-compute on CPU to reduce size_t division cost in kernel
const size_t k_scaled = (K + 127) / 128;
dim3 grid(M / 128, k_scaled);
dim3 block(32, 16);

#define LAUNCH_KERNEL(POW_2_SCALES, VEC_SIZE) \
FusedTransposeSplitQuantKernel<phi::dtype::float8_e4m3fn, \
POW_2_SCALES, \
VEC_SIZE> \
<<<grid, block, 0, stream>>>(x.data<phi::dtype::bfloat16>(), \
meta_gpu.data<int64_t>(), \
num_experts, \
K);
#define DTYPE_CASE(dtype, type) dtype == phi::DataType::type
#define LAUNCH_KERNEL(T, POW_2_SCALES, VEC_SIZE) \
FusedTransposeSplitQuantKernel<T, \
phi::dtype::float8_e4m3fn, \
POW_2_SCALES, \
VEC_SIZE><<<grid, block, 0, stream>>>( \
x.data<T>(), \
input_scales ? input_scales.get_ptr()->data<float>() : nullptr, \
meta_gpu.data<int64_t>(), \
num_experts, \
K, \
k_scaled);
#define DISPATCH_DATATYPE(POW_2_SCALES, VEC_SIZE) \
if (DTYPE_CASE(x.dtype(), BFLOAT16)) { \
LAUNCH_KERNEL(phi::bfloat16, POW_2_SCALES, VEC_SIZE); \
} else if (DTYPE_CASE(x.dtype(), FLOAT8_E4M3FN)) { \
LAUNCH_KERNEL(phi::float8_e4m3fn, POW_2_SCALES, VEC_SIZE); \
}

#define LAUNCH_KERNEL_PARTIAL(VEC_SIZE) \
if (pow_2_scales) { \
LAUNCH_KERNEL(true, VEC_SIZE); \
DISPATCH_DATATYPE(true, VEC_SIZE); \
} else { \
LAUNCH_KERNEL(false, VEC_SIZE); \
DISPATCH_DATATYPE(false, VEC_SIZE); \
}

if (K % 4 == 0) {
@@ -296,7 +337,8 @@ PD_REGISTER_KERNEL(fused_transpose_split_quant,
double,
int,
int64_t,
phi::dtype::bfloat16) {
phi::dtype::bfloat16,
phi::dtype::float8_e4m3fn) {
kernel->OutputAt(0).SetDataType(phi::DataType::FLOAT8_E4M3FN);
kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);
}
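
Putting the kernel changes together: the op now optionally dequantizes an FP8 input with per-row, per-128-column-block scales, then re-quantizes to FP8 with per-column, per-128-row-block scales, splitting along M per expert and transposing each output. A compact NumPy sketch of that semantics under the layout assumptions above; it does not model FP8 rounding, the exact ComputeScale behavior, or pow_2_scales:

import numpy as np

FP8_E4M3_MAX = 448.0  # largest finite e4m3 magnitude

def fused_transpose_split_quant_ref(x, input_scales, tokens_per_expert):
    x = np.asarray(x, dtype=np.float32).copy()           # [M, K]
    M, K = x.shape
    if input_scales is not None:                         # FP8 input: dequantize first
        for kb in range((K + 127) // 128):
            cols = slice(kb * 128, min((kb + 1) * 128, K))
            x[:, cols] *= input_scales[:, kb:kb + 1]
    outs, scales = [], []
    off_m = 0
    for tokens in tokens_per_expert:                     # split along M per expert
        xe = x[off_m:off_m + tokens]                     # [tokens, K]
        amax = np.abs(xe.reshape(tokens // 128, 128, K)).max(axis=1)
        q = FP8_E4M3_MAX / np.maximum(amax, 1e-12)       # quant multiplier per block/col
        outs.append((xe * np.repeat(q, 128, axis=0)).T)  # [K, tokens], FP8 cast omitted
        scales.append(1.0 / q)                           # dequant scales [tokens // 128, K]
        off_m += tokens
    return outs, scales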
3 changes: 2 additions & 1 deletion paddle/phi/ops/yaml/fused_ops.yaml
@@ -916,12 +916,13 @@
support_dygraph_mode : true

- op: fused_transpose_split_quant
args: (Tensor x, IntArray tokens_per_expert, bool pow_2_scales=false)
args: (Tensor x, Tensor input_scales, IntArray tokens_per_expert, bool pow_2_scales=false)
output: Tensor[](out){tokens_per_expert.size()}, Tensor[](scales){tokens_per_expert.size()}
infer_meta:
func: FusedTransposeSplitQuantInferMeta
kernel:
func: fused_transpose_split_quant
optional: input_scales
support_dygraph_mode : true

- op: fused_weighted_swiglu_act_quant
8 changes: 5 additions & 3 deletions python/paddle/incubate/nn/functional/fp8.py
@@ -173,7 +173,9 @@ def fused_swiglu_weighted_bwd(
return _C_ops.fused_swiglu_weighted_bwd(o1, do2_s, unzipped_probs)


def fused_transpose_split_quant(x, tokens_per_expert, pow_2_scales=False):
def fused_transpose_split_quant(
x, input_scales, tokens_per_expert, pow_2_scales=False
):
"""
Applies fused transpose, split, and quantization operation for Mixture of Experts (MoE) models.

@@ -215,7 +217,7 @@ def fused_transpose_split_quant(x, tokens_per_expert, pow_2_scales=False):
>>> x = paddle.randn([384, 512], dtype='bfloat16')
>>> x = paddle.clip(x, min=-50, max=50)
>>> tokens_per_expert = [128, 128, 128]
>>> outs, scales = F.fused_transpose_split_quant(x, tokens_per_expert, pow_2_scales=True)
>>> outs, scales = F.fused_transpose_split_quant(x, None, tokens_per_expert, pow_2_scales=True)
>>> print(outs[0].shape)
[512, 128]
>>> print(scales[0].shape)
@@ -228,7 +230,7 @@ def fused_transpose_split_quant(x, tokens_per_expert, pow_2_scales=False):

if in_dynamic_or_pir_mode():
return _C_ops.fused_transpose_split_quant(
x, tokens_per_expert, pow_2_scales
x, input_scales, tokens_per_expert, pow_2_scales
)

