Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 14 additions & 14 deletions csrc/moe/marlin_moe_wna16/kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,20 @@
#include "quantization/gptq_marlin/marlin_dtypes.cuh"
#include "core/scalar_type.hpp"

#define MARLIN_KERNEL_PARAMS \
const int4 *__restrict__ A, const int4 *__restrict__ B, \
int4 *__restrict__ C, int4 *__restrict__ C_tmp, \
const int4 *__restrict__ b_bias_ptr, \
const float *__restrict__ a_scales_ptr, \
const int4 *__restrict__ scales_ptr, \
const uint16_t *__restrict__ global_scale_ptr, \
const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \
const int32_t *__restrict__ sorted_token_ids_ptr, \
const int32_t *__restrict__ expert_ids_ptr, \
const int32_t *__restrict__ num_tokens_past_padded_ptr, \
const float *__restrict__ topk_weights_ptr, int top_k, \
bool mul_topk_weights, bool is_ep, int num_groups, int prob_m, \
int prob_n, int prob_k, int *locks, bool has_bias, bool use_atomic_add, \
#define MARLIN_KERNEL_PARAMS \
const int4 *__restrict__ A, const int4 *__restrict__ B, \
int4 *__restrict__ C, int4 *__restrict__ C_tmp, \
const int4 *__restrict__ b_bias_ptr, \
const float *__restrict__ a_scales_ptr, \
const int4 *__restrict__ scales_ptr, \
const uint16_t *__restrict__ global_scale_ptr, \
const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \
const int32_t *__restrict__ sorted_token_ids_ptr, \
const int32_t *__restrict__ expert_ids_ptr, \
const int32_t *__restrict__ num_tokens_past_padded_ptr, \
const float *__restrict__ topk_weights_ptr, int top_k, \
bool mul_topk_weights, int num_groups, int prob_m, int prob_n, \
int prob_k, int *locks, bool has_bias, bool use_atomic_add, \
bool use_fp32_reduce

namespace MARLIN_NAMESPACE_NAME {
Expand Down
28 changes: 2 additions & 26 deletions csrc/moe/marlin_moe_wna16/marlin_template.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@ __global__ void Marlin(
const float* __restrict__ topk_weights_ptr, // moe top weights
int top_k, // num of experts per token
bool mul_topk_weights, // mul topk weights or not
bool is_ep, // expert parallelism
int num_groups, // number of scale groups per output channel
int prob_m, // batch dimension m
int prob_n, // output dimension n
Expand Down Expand Up @@ -273,7 +272,6 @@ __global__ void Marlin(
const float* __restrict__ topk_weights_ptr, // moe top weights
int top_k, // num of experts per token
bool mul_topk_weights, // mul topk weights or not
bool is_ep, // expert parallelism
int num_groups, // number of scale groups per output channel
int prob_m, // batch dimension m
int prob_n, // output dimension n
Expand Down Expand Up @@ -376,14 +374,6 @@ __global__ void Marlin(

// parallel: num valid moe blocks
int parallel = num_tokens_past_padded / moe_block_size;
int num_valid_blocks = parallel;
if (is_ep) {
for (int i = 0; i < parallel; i++) {
if (expert_ids_ptr[i] == -1) num_valid_blocks--;
}
}
int num_invalid_blocks = parallel - num_valid_blocks;
parallel = num_valid_blocks;

int k_tiles = prob_k / 16 / thread_k_blocks;
int n_tiles = prob_n / 16 / thread_n_blocks;
Expand Down Expand Up @@ -538,22 +528,8 @@ __global__ void Marlin(
if (par_id >= parallel) return;

old_expert_id = expert_id;
if (num_invalid_blocks > 0) {
int skip_count = par_id;
for (int i = 0; i < num_tokens_past_padded / moe_block_size; i++) {
expert_id = expert_ids_ptr[i];
if (expert_id != -1) {
if (skip_count == 0) {
block_id = i;
break;
};
skip_count--;
};
}
} else {
block_id = par_id;
expert_id = expert_ids_ptr[block_id];
}
block_id = par_id;
expert_id = expert_ids_ptr[block_id];

if constexpr (b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) {
uint16_t val = global_scale_ptr[expert_id];
Expand Down
28 changes: 14 additions & 14 deletions csrc/moe/marlin_moe_wna16/ops.cu
Original file line number Diff line number Diff line change
Expand Up @@ -336,14 +336,14 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
void* perm, void* a_tmp, void* sorted_token_ids,
void* expert_ids, void* num_tokens_past_padded,
void* topk_weights, int moe_block_size, int num_experts,
int top_k, bool mul_topk_weights, bool is_ep, int prob_m,
int prob_n, int prob_k, void* workspace,
vllm::ScalarType const& a_type, vllm::ScalarType const& b_type,
vllm::ScalarType const& c_type, vllm::ScalarType const& s_type,
bool has_bias, bool has_act_order, bool is_k_full, bool has_zp,
int num_groups, int group_size, int dev, cudaStream_t stream,
int thread_k, int thread_n, int sms, int blocks_per_sm,
bool use_atomic_add, bool use_fp32_reduce, bool is_zp_float) {
int top_k, bool mul_topk_weights, int prob_m, int prob_n,
int prob_k, void* workspace, vllm::ScalarType const& a_type,
vllm::ScalarType const& b_type, vllm::ScalarType const& c_type,
vllm::ScalarType const& s_type, bool has_bias,
bool has_act_order, bool is_k_full, bool has_zp, int num_groups,
int group_size, int dev, cudaStream_t stream, int thread_k,
int thread_n, int sms, int blocks_per_sm, bool use_atomic_add,
bool use_fp32_reduce, bool is_zp_float) {
int thread_m_blocks = div_ceil(moe_block_size, 16);
bool m_block_size_8 = moe_block_size == 8;
bool is_a_8bit = a_type.size_bits() == 8;
Expand Down Expand Up @@ -523,7 +523,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
kernel<<<blocks, num_threads, max_shared_mem, stream>>>(
A_ptr, B_ptr, C_ptr, C_tmp_ptr, bias_ptr, a_s_ptr, b_s_ptr, g_s_ptr, zp_ptr, g_idx_ptr,
sorted_token_ids_ptr, expert_ids_ptr, num_tokens_past_padded_ptr,
topk_weights_ptr, top_k, mul_topk_weights, is_ep, num_groups, prob_m,
topk_weights_ptr, top_k, mul_topk_weights, num_groups, prob_m,
prob_n, prob_k, locks, has_bias, use_atomic_add, use_fp32_reduce);
// clang-format on
}
Expand All @@ -541,7 +541,7 @@ torch::Tensor moe_wna16_marlin_gemm(
std::optional<torch::Tensor> const& perm_or_none, torch::Tensor& workspace,
torch::Tensor& sorted_token_ids, torch::Tensor& expert_ids,
torch::Tensor& num_tokens_past_padded, torch::Tensor& topk_weights,
int64_t moe_block_size, int64_t top_k, bool mul_topk_weights, bool is_ep,
int64_t moe_block_size, int64_t top_k, bool mul_topk_weights,
vllm::ScalarTypeId const& b_type_id, int64_t size_m, int64_t size_n,
int64_t size_k, bool is_k_full, bool use_atomic_add, bool use_fp32_reduce,
bool is_zp_float, int64_t thread_k, int64_t thread_n,
Expand Down Expand Up @@ -855,9 +855,9 @@ torch::Tensor moe_wna16_marlin_gemm(
perm.data_ptr(), a_tmp.data_ptr(), sorted_token_ids.data_ptr(),
expert_ids.data_ptr(), num_tokens_past_padded.data_ptr(),
topk_weights.data_ptr(), moe_block_size, num_experts, top_k,
mul_topk_weights, is_ep, size_m, size_n, size_k, workspace.data_ptr(),
a_type, b_type, c_type, s_type, has_bias, has_act_order, is_k_full,
has_zp, num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
mul_topk_weights, size_m, size_n, size_k, workspace.data_ptr(), a_type,
b_type, c_type, s_type, has_bias, has_act_order, is_k_full, has_zp,
num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
thread_k, thread_n, sms, blocks_per_sm, use_atomic_add, use_fp32_reduce,
is_zp_float);

Expand All @@ -866,4 +866,4 @@ torch::Tensor moe_wna16_marlin_gemm(

TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("moe_wna16_marlin_gemm", &moe_wna16_marlin_gemm);
}
}
2 changes: 1 addition & 1 deletion csrc/moe/torch_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
"Tensor sorted_token_ids,"
"Tensor! expert_ids, Tensor! num_tokens_past_padded,"
"Tensor! topk_weights, int moe_block_size, int top_k, "
"bool mul_topk_weights, bool is_ep, int b_type_id,"
"bool mul_topk_weights, int b_type_id,"
"int size_m, int size_n, int size_k,"
"bool is_full_k, bool use_atomic_add,"
"bool use_fp32_reduce, bool is_zp_float,"
Expand Down
3 changes: 0 additions & 3 deletions vllm/_custom_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2162,7 +2162,6 @@ def moe_wna16_marlin_gemm(
moe_block_size: int,
top_k: int,
mul_topk_weights: bool,
is_ep: bool,
b_q_type: ScalarType,
size_m: int,
size_n: int,
Expand Down Expand Up @@ -2194,7 +2193,6 @@ def moe_wna16_marlin_gemm(
moe_block_size,
top_k,
mul_topk_weights,
is_ep,
b_q_type.id,
size_m,
size_n,
Expand Down Expand Up @@ -2256,7 +2254,6 @@ def moe_wna16_marlin_gemm_fake(
moe_block_size: int,
top_k: int,
mul_topk_weights: bool,
is_ep: bool,
b_q_type: ScalarType,
size_m: int,
size_n: int,
Expand Down
2 changes: 0 additions & 2 deletions vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,6 @@ def _fused_marlin_moe(
moe_block_size=block_size_m,
top_k=num_topk,
mul_topk_weights=apply_router_weight_on_input,
is_ep=expert_map is not None,
b_q_type=quant_type,
size_m=M,
size_n=2 * N,
Expand Down Expand Up @@ -187,7 +186,6 @@ def _fused_marlin_moe(
moe_block_size=block_size_m,
top_k=1,
mul_topk_weights=not apply_router_weight_on_input,
is_ep=expert_map is not None,
b_q_type=quant_type,
size_m=M * num_topk,
size_n=K,
Expand Down
Loading