Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
222 changes: 114 additions & 108 deletions python/sglang/jit_kernel/csrc/gemm/marlin_moe/marlin_template.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
#include "../marlin/marlin.cuh"
#include "../marlin/marlin_dtypes.cuh"

#include <type_traits>

#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
static_assert( \
std::is_same<scalar_t, half>::value || std::is_same<scalar_t, nv_bfloat16>::value, \
Expand Down Expand Up @@ -355,6 +357,7 @@ __global__ void Marlin(
constexpr bool has_zp = w_type == host::kU4 || w_type == host::kU8;
constexpr bool is_int_type =
w_type == host::kU4 || w_type == host::kU8 || w_type == host::kU4B8 || w_type == host::kU8B128;
constexpr bool is_8bit_scale = s_type.size_bits() == 8;
// see comments of dequant.h for more details
constexpr bool dequant_skip_flop = w_type == host::kFE4M3fn || w_type == host::kFE2M1f && s_type == host::kFE4M3fn ||
has_zp && !is_zp_float && !std::is_same<scalar_t, nv_bfloat16>::value ||
Expand All @@ -368,7 +371,7 @@ __global__ void Marlin(
static_assert(thread_m_blocks == 1 || !m_block_size_8);
constexpr int moe_block_size = m_block_size_8 ? 8 : (16 * thread_m_blocks);
const int group_size = (!has_act_order && group_blocks == -1) ? prob_k : prob_k / num_groups;
const int scales_expert_stride = prob_n * prob_k / group_size / (w_type == host::kFE2M1f ? 16 : 8);
const int scales_expert_stride = prob_n * prob_k / group_size / (is_8bit_scale ? 16 : 8);
const int zp_expert_stride =
is_zp_float ? prob_n * prob_k / group_size / 8 : prob_n * prob_k / group_size / (pack_factor * 4);
const int b_bias_expert_stride = prob_n / 8;
Expand Down Expand Up @@ -439,52 +442,75 @@ __global__ void Marlin(
locks_off = (iters * blockIdx.x) / k_tiles - 1;
}

int prob_m_top_k = prob_m * top_k;
// read moe block data given block_id
// block_sorted_ids / block_num_valid_tokens / block_topk_weights
auto read_moe_block_data = [&](int block_id) {
block_num_valid_tokens = moe_block_size;

cp_async4_pred(
sh_block_sorted_ids_int4 + threadIdx.x,
reinterpret_cast<const int4*>(sorted_token_ids_ptr) +
(block_id * moe_block_size / 4 + threadIdx.x),
threadIdx.x < moe_block_size / 4);

cp_async_fence();
cp_async_wait<0>();

__syncthreads();

if (threadIdx.x >= threads - 32) {
constexpr int size_per_thread = div_ceil(moe_block_size, 32);
int lane_id = threadIdx.x - (threads - 32);

int local_count = 0;
#pragma unroll
for (int i = 0; i < moe_block_size / 4; i++) {
int4 sorted_token_ids_int4 =
reinterpret_cast<const int4*>(sorted_token_ids_ptr)[block_id * moe_block_size / 4 + i];
int* sorted_token_ids = reinterpret_cast<int*>(&sorted_token_ids_int4);
#pragma unroll
for (int j = 0; j < 4; j++) {
if (sorted_token_ids[j] >= prob_m * top_k) {
block_num_valid_tokens = i * 4 + j;
break;
for (int i = 0; i < size_per_thread; i++) {
int j = lane_id * size_per_thread + i;
if (j < moe_block_size) {
int idx = sh_block_sorted_ids[j];
if (idx < prob_m_top_k) local_count++;
}
}
if (block_num_valid_tokens != moe_block_size) break;
}

__syncthreads();
int tid4 = threadIdx.x / 4;
if (threadIdx.x % 4 == 0 && threadIdx.x < block_num_valid_tokens) {
sh_block_sorted_ids_int4[tid4] =
reinterpret_cast<const int4*>(sorted_token_ids_ptr)[block_id * moe_block_size / 4 + tid4];
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
if constexpr (moe_block_size >= 16)
local_count += __shfl_down_sync(0xFFFFFFFF, local_count, 16);
if constexpr (moe_block_size >= 8)
local_count += __shfl_down_sync(0xFFFFFFFF, local_count, 8);
if constexpr (moe_block_size >= 4)
local_count += __shfl_down_sync(0xFFFFFFFF, local_count, 4);
if constexpr (moe_block_size >= 2)
local_count += __shfl_down_sync(0xFFFFFFFF, local_count, 2);

local_count += __shfl_down_sync(0xFFFFFFFF, local_count, 1);
block_num_valid_tokens = local_count;
#else
block_num_valid_tokens = __reduce_add_sync(0xffffffff, local_count);
#endif

#pragma unroll
for (int i = 0; i < 4; i++)
sh_rd_block_sorted_ids[tid4 * 4 + i] = sh_block_sorted_ids[tid4 * 4 + i] / top_k;
if (lane_id == 0) reinterpret_cast<int*>(sh_new)[0] = block_num_valid_tokens;
}

if (threadIdx.x < moe_block_size) {
int idx = sh_block_sorted_ids[threadIdx.x];
sh_rd_block_sorted_ids[threadIdx.x] = idx / top_k;

if (mul_topk_weights) {
#pragma unroll
for (int i = 0; i < 4; i++) {
int idx = tid4 * 4 + i;
// idx = idx < block_num_valid_tokens ? idx : 0;
if (idx < block_num_valid_tokens) {
if constexpr (w_type == host::kFE2M1f && s_type == host::kFE4M3fn) {
sh_block_topk_weights[idx] =
__hmul2(global_scale, Dtype::num2num2(Dtype::float2num(topk_weights_ptr[sh_block_sorted_ids[idx]])));
} else {
sh_block_topk_weights[idx] =
Dtype::num2num2(Dtype::float2num(topk_weights_ptr[sh_block_sorted_ids[idx]]));
}
}
idx = idx < prob_m_top_k ? idx : 0;
scalar_t topk_weight_tmp = Dtype::float2num(topk_weights_ptr[idx]);
if constexpr (w_type == host::kFE2M1f && s_type == host::kFE4M3fn) {
sh_block_topk_weights[threadIdx.x] =
__hmul2(global_scale, Dtype::num2num2(topk_weight_tmp));
} else {
sh_block_topk_weights[threadIdx.x] = Dtype::num2num2(topk_weight_tmp);
}
}
}

__syncthreads();

block_num_valid_tokens = reinterpret_cast<int*>(sh_new)[0];
__syncthreads();
};

Expand Down Expand Up @@ -626,11 +652,10 @@ __global__ void Marlin(
constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta;

// Scale sizes/strides without act_order
int s_gl_stride = prob_n / 8;
constexpr int s_sh_stride = 16 * thread_n_blocks / 8;
constexpr int s_tb_groups = !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks
? thread_k_blocks / group_blocks / (w_type == host::kFE2M1f ? 2 : 1)
: 1;
int s_gl_stride = prob_n / (is_8bit_scale ? 16 : 8);
constexpr int s_sh_stride = 16 * thread_n_blocks / (is_8bit_scale ? 16 : 8);
constexpr int s_tb_groups =
!has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks ? thread_k_blocks / group_blocks : 1;
constexpr int s_sh_stage = s_tb_groups * s_sh_stride;
int s_gl_rd_delta = s_gl_stride;

Expand Down Expand Up @@ -681,13 +706,15 @@ __global__ void Marlin(
if constexpr (!has_act_order) {
if constexpr (group_blocks == -1) {
s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
} else if constexpr (group_blocks >= thread_k_blocks) {
s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + s_sh_stride * slice_col + threadIdx.x;
} else {
s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) / (w_type == host::kFE2M1f ? 2 : 1) +
s_sh_stride * slice_col + threadIdx.x;
s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks + threadIdx.x / s_sh_stride) +
s_sh_stride * slice_col + threadIdx.x % s_sh_stride;
}
}
auto s_sh_wr = threadIdx.x;
bool s_sh_wr_pred = threadIdx.x < s_sh_stride;
bool s_sh_wr_pred = threadIdx.x < s_sh_stage;

// Zero-points
int zp_gl_rd;
Expand All @@ -705,15 +732,7 @@ __global__ void Marlin(
// we scale a `half2` tile in column-major layout in the former and in
// row-major in the latter case.
int s_sh_rd;
if constexpr (group_blocks != -1 && w_type == host::kFE2M1f) {
auto warp_id = threadIdx.x / 32;
int n_warps = thread_n_blocks / 4;
int warp_row = warp_id / n_warps;

s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4;
s_sh_rd = s_sh_rd * 2 + (warp_row / group_blocks) % 2;

} else if constexpr (group_blocks != -1)
if constexpr (group_blocks != -1)
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4;
else if constexpr (group_blocks == -1 && (m_block_size_8 || (has_zp && !dequant_skip_flop)))
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 8;
Expand Down Expand Up @@ -907,43 +926,21 @@ __global__ void Marlin(
} else {
if constexpr (group_blocks != -1) {
int4* sh_s_stage = sh_s + s_sh_stage * pipe;

if constexpr (group_blocks >= thread_k_blocks) {
// Only fetch scales if this tile starts a new group
if (pipe % (group_blocks / thread_k_blocks) == 0) {
if (s_sh_wr_pred) {
cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]);
}
s_gl_rd += s_gl_rd_delta;
}
} else {
for (int i = 0; i < s_tb_groups; i++) {
if (s_sh_wr_pred) {
cp_async4(&sh_s_stage[i * s_sh_stride + s_sh_wr], &scales_ptr[s_gl_rd]);
}
s_gl_rd += s_gl_rd_delta;
if (pipe % div_ceil(group_blocks, thread_k_blocks) == 0) {
if (s_sh_wr_pred) {
cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]);
}
s_gl_rd += s_gl_rd_delta * s_tb_groups;
}
}

if constexpr (has_zp && group_blocks != -1) {
int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe;

if constexpr (group_blocks >= thread_k_blocks) {
// Only fetch zero-points if this tile starts a new group
if (pipe % (group_blocks / thread_k_blocks) == 0) {
if (zp_sh_wr_pred) {
cp_async4(&sh_zp_stage[zp_sh_wr], &zp_ptr[zp_gl_rd]);
}
zp_gl_rd += zp_gl_rd_delta;
}
} else {
for (int i = 0; i < zp_tb_groups; i++) {
if (zp_sh_wr_pred) {
cp_async4(&sh_zp_stage[i * zp_sh_stride + zp_sh_wr], &zp_ptr[zp_gl_rd]);
}
zp_gl_rd += zp_gl_rd_delta;
if (pipe % div_ceil(group_blocks, thread_k_blocks) == 0) {
if (zp_sh_wr_pred) {
cp_async4(&sh_zp_stage[zp_sh_wr], &zp_ptr[zp_gl_rd]);
}
zp_gl_rd += zp_gl_rd_delta * zp_tb_groups;
}
}
}
Expand Down Expand Up @@ -1021,35 +1018,33 @@ __global__ void Marlin(
}
} else if constexpr (group_blocks != -1) {
if constexpr (group_blocks >= thread_k_blocks) {
if (k % b_sh_wr_iters == 0) {
int4* sh_s_stage =
sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * (pipe / (group_blocks / thread_k_blocks)));
reinterpret_cast<int4*>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd];
} else {
reinterpret_cast<int4*>(&frag_s[1])[0] = reinterpret_cast<int4*>(&frag_s[0])[0];
constexpr int g = group_blocks / thread_k_blocks;
if (pipe % g == 0) {
if (k % b_sh_wr_iters == 0) {
int4* sh_s_stage = sh_s + s_sh_stage * (g * (pipe / g));
reinterpret_cast<int4*>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd];
} else {
reinterpret_cast<int4*>(&frag_s[1])[0] = reinterpret_cast<int4*>(&frag_s[0])[0];
}
}
} else {
auto warp_id = threadIdx.x / 32;
int n_warps = thread_n_blocks / 4;

int warp_row = warp_id / n_warps;

int cur_k = warp_row * 16;
cur_k += k_iter_size * (k % b_sh_wr_iters);

int k_blocks = cur_k / 16;
int cur_group_id = k_blocks / (group_blocks * (w_type == host::kFE2M1f ? 2 : 1));
int cur_group_id = k_blocks / group_blocks;

int4* sh_s_stage = sh_s + s_sh_stage * pipe;

if constexpr (w_type_id != host::kFE2M1f.id()) {
reinterpret_cast<int4*>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride];
} else if constexpr (group_blocks == 1 || thread_k_blocks > 4) {
reinterpret_cast<int2*>(&frag_s[k % 2])[0] =
reinterpret_cast<int2*>(sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)];
if constexpr (!is_8bit_scale) {
reinterpret_cast<int4*>(&frag_s[k % 2])[0] =
sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride];
} else {
reinterpret_cast<int2*>(&frag_s[k % 2])[0] =
reinterpret_cast<int2*>(sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride) + k % 2];
reinterpret_cast<int2*>(sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)];
}
}
}
Expand Down Expand Up @@ -1243,17 +1238,18 @@ __global__ void Marlin(
}
}

// Commented out FP4/FP8 scale dequantization since we don't generate
// kFE2M1f kernels to reduce compilation time
// if constexpr (w_type == host::kFE2M1f) {
// int s_quant_0 = reinterpret_cast<int*>(frag_s[k2])[0];
// int s_quant_1 = reinterpret_cast<int*>(frag_s[k2])[1];
//
// dequant_fp8_scales<scalar_t2, s_type_id>(
// s_quant_0, reinterpret_cast<scalar_t2*>(&frag_s[k2]));
// dequant_fp8_scales<scalar_t2, s_type_id>(
// s_quant_1, reinterpret_cast<scalar_t2*>(&frag_s[k2]) + 2);
// }
// FP4/FP8 scale dequantization (E4M3 for NVFP4 and E8M0 for MXFP4).
if constexpr ((s_type == host::kFE4M3fn || s_type == host::kFE8M0fnu) &&
!(std::is_same<scalar_t2, half2>::value &&
s_type == host::kFE8M0fnu)) {
int s_quant_0 = reinterpret_cast<int*>(frag_s[k2])[0];
int s_quant_1 = reinterpret_cast<int*>(frag_s[k2])[1];

dequant_fp8_scales<scalar_t2, s_type_id>(
s_quant_0, reinterpret_cast<scalar_t2*>(&frag_s[k2]));
dequant_fp8_scales<scalar_t2, s_type_id>(
s_quant_1, reinterpret_cast<scalar_t2*>(&frag_s[k2]) + 2);
}

// We have the m dimension as the inner loop in order to encourage overlapping
// dequantization and matmul operations.
Expand Down Expand Up @@ -1882,8 +1878,18 @@ __global__ void Marlin(
slice_k_start_shared_fetch = slice_k_start;
slice_n_offset = act_s_col_tb_stride * slice_col;
} else {
s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x;
if constexpr (group_blocks == -1) {
s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x;
} else if constexpr (group_blocks >= thread_k_blocks) {
s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + s_sh_stride * slice_col + threadIdx.x;
zp_gl_rd = zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + zp_sh_stride * slice_col + threadIdx.x;
} else {
s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks + threadIdx.x / s_sh_stride) +
s_sh_stride * slice_col + threadIdx.x % s_sh_stride;
zp_gl_rd = zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks + threadIdx.x / zp_sh_stride) +
zp_sh_stride * slice_col + threadIdx.x % zp_sh_stride;
}
}
start_pipes();
}
Expand Down
Loading
Loading