Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 17 additions & 9 deletions csrc/trtllm_moe_allreduce_fusion.cu
Original file line number Diff line number Diff line change
Expand Up @@ -83,19 +83,18 @@ void trtllm_moe_allreduce_fusion(

void trtllm_moe_finalize_allreduce_fusion(
TensorView allreduce_in, TensorView residual_in, TensorView norm_weight,
TensorView expanded_idx_to_permuted_idx, TensorView norm_out, TensorView residual_out,
bool launch_with_pdl, TensorView workspace, int64_t const world_rank, int64_t const world_size,
double const eps, Optional<TensorView> shared_expert_output,
Optional<TensorView> expert_scale_factor) {
TensorView expanded_idx_to_permuted_idx, Optional<TensorView> norm_out,
Optional<TensorView> residual_out, Optional<TensorView> quant_out,
Optional<TensorView> scale_out, bool launch_with_pdl, TensorView workspace,
int64_t const world_rank, int64_t const world_size, double const eps,
Optional<TensorView> shared_expert_output, Optional<TensorView> expert_scale_factor,
Optional<float> routed_scaling_factor) {
DISPATCH_FLOATING_TYPES_FOR_ALLREDUCE(residual_in.dtype(), c_type, [&] {
MoeFinalizeAllReduceFusionParams<c_type> params;

int hidden_dim = residual_in.size(-1);
int top_k = expanded_idx_to_permuted_idx.size(-1);

params.quant_out = nullptr;
params.scale_out = nullptr;

params.nranks = static_cast<int>(world_size);
params.rank = static_cast<int>(world_rank);
// size: num_token * hidden_dim
Expand All @@ -122,8 +121,17 @@ void trtllm_moe_finalize_allreduce_fusion(
shared_expert_output.has_value() ? shared_expert_output.value().data_ptr() : nullptr;

// output tensors
params.norm_out = norm_out.data_ptr();
params.residual_out = residual_out.data_ptr();
params.residual_out = residual_out.has_value()
? reinterpret_cast<void*>(residual_out.value().data_ptr())
: nullptr;
params.norm_out =
norm_out.has_value() ? reinterpret_cast<void*>(norm_out.value().data_ptr()) : nullptr;
params.quant_out =
quant_out.has_value() ? reinterpret_cast<void*>(quant_out.value().data_ptr()) : nullptr;
params.scale_out =
scale_out.has_value() ? reinterpret_cast<void*>(scale_out.value().data_ptr()) : nullptr;
params.routed_scaling_factor =
routed_scaling_factor.has_value() ? routed_scaling_factor.value() : 1.0f;

auto status = moefinalize_allreduce_fusion_op(params, launch_with_pdl);
TVM_FFI_ICHECK(status == cudaSuccess)
Expand Down
25 changes: 20 additions & 5 deletions flashinfer/comm/trtllm_ar.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,22 +350,25 @@ def trtllm_moe_allreduce_fusion(

@register_custom_op(
"flashinfer::trtllm_moe_finalize_allreduce_fusion",
mutates_args=["residual_out", "norm_out"],
mutates_args=["residual_out", "norm_out", "quant_out", "scale_out"],
)
def trtllm_moe_finalize_allreduce_fusion(
allreduce_in: torch.Tensor,
residual_in: torch.Tensor,
norm_weight: torch.Tensor,
expanded_idx_to_permuted_idx: torch.Tensor,
norm_out: torch.Tensor,
residual_out: torch.Tensor,
norm_out: Optional[torch.Tensor],
residual_out: Optional[torch.Tensor],
quant_out: Optional[torch.Tensor],
scale_out: Optional[torch.Tensor],
launch_with_pdl: bool,
workspace: torch.Tensor,
world_rank: int,
world_size: int,
eps: float,
shared_expert_output: Optional[torch.Tensor],
expert_scale_factor: Optional[torch.Tensor],
routed_scaling_factor: Optional[float],
) -> None:
module.trtllm_moe_finalize_allreduce_fusion(
allreduce_in,
Expand All @@ -374,13 +377,16 @@ def trtllm_moe_finalize_allreduce_fusion(
expanded_idx_to_permuted_idx,
norm_out,
residual_out,
quant_out,
scale_out,
launch_with_pdl,
workspace,
world_rank,
world_size,
eps,
shared_expert_output,
expert_scale_factor,
routed_scaling_factor,
)

return SimpleNamespace(
Expand Down Expand Up @@ -1098,15 +1104,18 @@ def trtllm_moe_finalize_allreduce_fusion(
residual_in: torch.Tensor,
norm_weight: torch.Tensor,
expanded_idx_to_permuted_idx: torch.Tensor,
norm_out: torch.Tensor,
residual_out: torch.Tensor,
norm_out: Optional[torch.Tensor],
residual_out: Optional[torch.Tensor],
quant_out: Optional[torch.Tensor],
scale_out: Optional[torch.Tensor],
workspace_ptrs: torch.Tensor,
launch_with_pdl: bool,
world_rank: int,
world_size: int,
eps: float,
shared_expert_output: Optional[torch.Tensor],
expert_scale_factor: Optional[torch.Tensor],
routed_scaling_factor: Optional[float],
) -> None:
"""
Parameters:
Expand All @@ -1116,13 +1125,16 @@ def trtllm_moe_finalize_allreduce_fusion(
- expanded_idx_to_permuted_idx: the expanded index to permuted index tensor. [token_num, top_k]
- norm_out: the norm output tensor. [token_num, hidden_dim]
- residual_out: the residual output tensor. [token_num, hidden_dim]
- quant_out: the quant output tensor. [token_num // 4, hidden_dim], fp16/bf16 -> fp4
- scale_out: the scale output tensor. [token_num // SF_VEC_SIZE, hidden_dim], fp16/bf16 -> fp4
- workspace_ptrs: the workspace pointers.
- launch_with_pdl: whether to launch with pdl.
- world_rank: the rank of the current process.
- world_size: the size of the process group.
- eps: the epsilon value.
- shared_expert_output: the shared expert output tensor. [token_num, hidden_dim]
- expert_scale_factor: the expert scale factor tensor. [token_num, top_k]
- routed_scaling_factor: the routed scaling factor.
"""

required_lamport_comm_size = allreduce_in.numel() * 2 * world_size
Expand All @@ -1140,11 +1152,14 @@ def trtllm_moe_finalize_allreduce_fusion(
expanded_idx_to_permuted_idx=expanded_idx_to_permuted_idx,
norm_out=norm_out,
residual_out=residual_out,
quant_out=quant_out,
scale_out=scale_out,
workspace=workspace_ptrs,
launch_with_pdl=launch_with_pdl,
world_rank=world_rank,
world_size=world_size,
eps=eps,
shared_expert_output=shared_expert_output,
expert_scale_factor=expert_scale_factor,
routed_scaling_factor=routed_scaling_factor,
)
29 changes: 21 additions & 8 deletions include/flashinfer/comm/trtllm_moe_allreduce_fusion.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -706,6 +706,7 @@ struct MoeFinalizeAllReduceFusionParams : public AllReduceFusionParams<T> {
// [num_tokens, top_k]
int32_t* expanded_idx_to_permuted_idx = nullptr;
// allreduce_in [maxPermutedPaddedCount, hidden_dim]
float routed_scaling_factor = 1.0f;
};

template <int NRanks>
Expand Down Expand Up @@ -1277,6 +1278,8 @@ __global__ void moefinalize_allreduce_fusion_kernel_oneshot_lamport(

int top_k = params.top_k;
bool use_scale_factor = params.expert_scale_factor != nullptr;
float routed_scaling_factor = params.routed_scaling_factor;
bool use_routed_scaling_factor = routed_scaling_factor != 1.0f;

// Persistent Kernel
// Each cluster iterate through all token it need to handle
Expand All @@ -1297,20 +1300,28 @@ __global__ void moefinalize_allreduce_fusion_kernel_oneshot_lamport(

int thread_offset_across_token =
permuted_idx * params.hidden_dim + thread_offset_within_token;
float block_scale = 1.0;
if (use_scale_factor) {
block_scale =
static_cast<float>(static_cast<ScaleType*>(params.expert_scale_factor)[expanded_idx]);
}

vec_t<T, VEC_SIZE> permuted_data;
permuted_data.load(reinterpret_cast<T*>(params.allreduce_in) + thread_offset_across_token);

// * acc += scale(data)
if (use_scale_factor) {
float block_scale =
static_cast<float>(static_cast<ScaleType*>(params.expert_scale_factor)[expanded_idx]);
#pragma unroll
for (int i = 0; i < VEC_SIZE; ++i) {
// assume computation is done in ScaleType
accumulator[i] += static_cast<T>(static_cast<float>(permuted_data[i]) * block_scale);
}
} else {
accumulator = vec_add<T, VEC_SIZE>(accumulator, permuted_data);
}
}

// Apply the global routed scaling once after accumulating all routed experts.
if (use_routed_scaling_factor) {
#pragma unroll
for (int i = 0; i < VEC_SIZE; ++i) {
// assume computation is done in ScaleType
accumulator[i] += static_cast<T>(static_cast<float>(permuted_data[i]) * block_scale);
accumulator[i] = static_cast<T>(static_cast<float>(accumulator[i]) * routed_scaling_factor);
}
}

Expand Down Expand Up @@ -1474,6 +1485,8 @@ cudaError_t moefinalize_allreduce_fusion_op(MoeFinalizeAllReduceFusionParams<T>
"allreduce_in, expanded_idx_to_permuted_idx and top_k must be set");
FLASHINFER_CHECK(params.size % params.hidden_dim == 0, "size must be a multiple of hidden_dim");
FLASHINFER_CHECK(params.hidden_dim % VEC_SIZE == 0, "hidden_dim must be a multiple of VEC_SIZE");
FLASHINFER_CHECK(params.residual_out || params.norm_out || params.quant_out,
"at least one of residual_out, norm_out, quant_out must be set");

auto status = DISPATCH_MOEFINALIZEREDUCTION(
params.nranks, params.residual_out, params.rms_gamma, params.quant_out, N_RANKS, RES, RMS,
Expand Down
17 changes: 16 additions & 1 deletion tests/comm/test_trtllm_moe_allreduce_fusion_finalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,14 @@ def rms_norm(x: torch.Tensor, weight: torch.Tensor = None, eps: float = 1e-6):

norm_out = torch.empty_like(residual)
residual_out = torch.empty_like(residual)
seq_len, hidden_size = norm_out.shape
quant_out = torch.empty(
seq_len * hidden_size // 4, dtype=dtype, device=device
)
scale_out = torch.empty(
seq_len * hidden_size // SF_VEC_SIZE, dtype=dtype, device=device
)
routed_scaling_factor = 2.5

# == Run kernel ==
torch.cuda.synchronize()
Expand All @@ -117,6 +125,9 @@ def rms_norm(x: torch.Tensor, weight: torch.Tensor = None, eps: float = 1e-6):
expert_scale_factor=scale,
norm_out=norm_out,
residual_out=residual_out,
quant_out=quant_out,
scale_out=scale_out,
routed_scaling_factor=routed_scaling_factor,
)
torch.cuda.current_stream().wait_stream(s)

Expand All @@ -138,6 +149,9 @@ def rms_norm(x: torch.Tensor, weight: torch.Tensor = None, eps: float = 1e-6):
expert_scale_factor=scale,
norm_out=norm_out,
residual_out=residual_out,
quant_out=quant_out,
scale_out=scale_out,
routed_scaling_factor=routed_scaling_factor,
)

# replay
Expand All @@ -148,7 +162,8 @@ def rms_norm(x: torch.Tensor, weight: torch.Tensor = None, eps: float = 1e-6):
# == Calculate reference output ==
expert_reduction = torch.sum(
fc2_output_clone[expanded_idx_to_permuted_idx]
* scale.unsqueeze(-1),
* scale.unsqueeze(-1)
* routed_scaling_factor,
dim=1,
)

Expand Down
Loading