From c2187b436020b5b82bc8f9c8de68cf8e76557482 Mon Sep 17 00:00:00 2001 From: Matt Murphy Date: Fri, 3 Apr 2026 07:08:22 +0000 Subject: [PATCH 1/6] Add routed scaling + quant/scale support to moe finalize allreduce kernel --- csrc/trtllm_moe_allreduce_fusion.cu | 26 +++++---- flashinfer/comm/trtllm_ar.py | 26 ++++++--- .../comm/trtllm_moe_allreduce_fusion.cuh | 8 ++- ...st_trtllm_moe_allreduce_fusion_finalize.py | 54 ++++++++++++++++--- 4 files changed, 91 insertions(+), 23 deletions(-) diff --git a/csrc/trtllm_moe_allreduce_fusion.cu b/csrc/trtllm_moe_allreduce_fusion.cu index ac1ce1714d..f6690a0096 100644 --- a/csrc/trtllm_moe_allreduce_fusion.cu +++ b/csrc/trtllm_moe_allreduce_fusion.cu @@ -83,19 +83,18 @@ void trtllm_moe_allreduce_fusion( void trtllm_moe_finalize_allreduce_fusion( TensorView allreduce_in, TensorView residual_in, TensorView norm_weight, - TensorView expanded_idx_to_permuted_idx, TensorView norm_out, TensorView residual_out, - bool launch_with_pdl, TensorView workspace, int64_t const world_rank, int64_t const world_size, - double const eps, Optional shared_expert_output, - Optional expert_scale_factor) { + TensorView expanded_idx_to_permuted_idx, Optional norm_out, + Optional residual_out, Optional quant_out, + Optional scale_out, bool launch_with_pdl, TensorView workspace, + int64_t const world_rank, int64_t const world_size, double const eps, + Optional shared_expert_output, Optional expert_scale_factor, + Optional routed_scaling_factor) { DISPATCH_FLOATING_TYPES_FOR_ALLREDUCE(residual_in.dtype(), c_type, [&] { MoeFinalizeAllReduceFusionParams params; int hidden_dim = residual_in.size(-1); int top_k = expanded_idx_to_permuted_idx.size(-1); - params.quant_out = nullptr; - params.scale_out = nullptr; - params.nranks = static_cast(world_size); params.rank = static_cast(world_rank); // size: num_token * hidden_dim @@ -122,8 +121,17 @@ void trtllm_moe_finalize_allreduce_fusion( shared_expert_output.has_value() ? shared_expert_output.value().data_ptr() : nullptr; // output tensors - params.norm_out = norm_out.data_ptr(); - params.residual_out = residual_out.data_ptr(); + params.residual_out = residual_out.has_value() + ? reinterpret_cast(residual_out.value().data_ptr()) + : nullptr; + params.norm_out = + norm_out.has_value() ? reinterpret_cast(norm_out.value().data_ptr()) : nullptr; + params.quant_out = + quant_out.has_value() ? reinterpret_cast(quant_out.value().data_ptr()) : nullptr; + params.scale_out = + scale_out.has_value() ? reinterpret_cast(scale_out.value().data_ptr()) : nullptr; + params.routed_scaling_factor = + routed_scaling_factor.has_value() ? routed_scaling_factor.value() : 1.0f; auto status = moefinalize_allreduce_fusion_op(params, launch_with_pdl); TVM_FFI_ICHECK(status == cudaSuccess) diff --git a/flashinfer/comm/trtllm_ar.py b/flashinfer/comm/trtllm_ar.py index 4392468e80..3034a185e7 100644 --- a/flashinfer/comm/trtllm_ar.py +++ b/flashinfer/comm/trtllm_ar.py @@ -348,22 +348,24 @@ def trtllm_moe_allreduce_fusion( @register_custom_op( "flashinfer::trtllm_moe_finalize_allreduce_fusion", - mutates_args=["residual_out", "norm_out"], + mutates_args=["residual_out", "norm_out", "quant_out", "scale_out"], ) def trtllm_moe_finalize_allreduce_fusion( allreduce_in: torch.Tensor, residual_in: torch.Tensor, norm_weight: torch.Tensor, expanded_idx_to_permuted_idx: torch.Tensor, - norm_out: torch.Tensor, - residual_out: torch.Tensor, - launch_with_pdl: bool, + norm_out: Optional[torch.Tensor], + residual_out: Optional[torch.Tensor], + quant_out: Optional[torch.Tensor], + scale_out: Optional[torch.Tensor], launch_with_pdl: bool, workspace: torch.Tensor, world_rank: int, world_size: int, eps: float, shared_expert_output: Optional[torch.Tensor], expert_scale_factor: Optional[torch.Tensor], + routed_scaling_factor: Optional[float], ) -> None: module.trtllm_moe_finalize_allreduce_fusion( allreduce_in, @@ -372,6 +374,8 @@ def trtllm_moe_finalize_allreduce_fusion( expanded_idx_to_permuted_idx, norm_out, residual_out, + quant_out, + scale_out, launch_with_pdl, workspace, world_rank, @@ -379,6 +383,7 @@ def trtllm_moe_finalize_allreduce_fusion( eps, shared_expert_output, expert_scale_factor, + routed_scaling_factor, ) return SimpleNamespace( @@ -1089,8 +1094,10 @@ def trtllm_moe_finalize_allreduce_fusion( residual_in: torch.Tensor, norm_weight: torch.Tensor, expanded_idx_to_permuted_idx: torch.Tensor, - norm_out: torch.Tensor, - residual_out: torch.Tensor, + norm_out: Optional[torch.Tensor], + residual_out: Optional[torch.Tensor], + quant_out: Optional[torch.Tensor], + scale_out: Optional[torch.Tensor], workspace_ptrs: torch.Tensor, launch_with_pdl: bool, world_rank: int, @@ -1098,6 +1105,7 @@ def trtllm_moe_finalize_allreduce_fusion( eps: float, shared_expert_output: Optional[torch.Tensor], expert_scale_factor: Optional[torch.Tensor], + routed_scaling_factor: Optional[float], ) -> None: """ Parameters: @@ -1107,6 +1115,8 @@ def trtllm_moe_finalize_allreduce_fusion( - expanded_idx_to_permuted_idx: the expanded index to permuted index tensor. [token_num, top_k] - norm_out: the norm output tensor. [token_num, hidden_dim] - residual_out: the residual output tensor. [token_num, hidden_dim] + - quant_out: the quant output tensor. [token_num // 4, hidden_dim], fp16/bf16 -> fp4 + - scale_out: the scale output tensor. [token_num // SF_VEC_SIZE, hidden_dim], fp16/bf16 -> fp4 - workspace_ptrs: the workspace pointers. - launch_with_pdl: whether to launch with pdl. - world_rank: the rank of the current process. @@ -1114,6 +1124,7 @@ def trtllm_moe_finalize_allreduce_fusion( - eps: the epsilon value. - shared_expert_output: the shared expert output tensor. [token_num, hidden_dim] - expert_scale_factor: the expert scale factor tensor. [token_num, top_k] + - routed_scaling_factor: the routed scaling factor. """ required_lamport_comm_size = allreduce_in.numel() * 2 * world_size @@ -1131,6 +1142,8 @@ def trtllm_moe_finalize_allreduce_fusion( expanded_idx_to_permuted_idx=expanded_idx_to_permuted_idx, norm_out=norm_out, residual_out=residual_out, + quant_out=quant_out, + scale_out=scale_out, workspace=workspace_ptrs, launch_with_pdl=launch_with_pdl, world_rank=world_rank, @@ -1138,4 +1151,5 @@ def trtllm_moe_finalize_allreduce_fusion( eps=eps, shared_expert_output=shared_expert_output, expert_scale_factor=expert_scale_factor, + routed_scaling_factor=routed_scaling_factor, ) diff --git a/include/flashinfer/comm/trtllm_moe_allreduce_fusion.cuh b/include/flashinfer/comm/trtllm_moe_allreduce_fusion.cuh index 143e25de9c..45ec6a2470 100644 --- a/include/flashinfer/comm/trtllm_moe_allreduce_fusion.cuh +++ b/include/flashinfer/comm/trtllm_moe_allreduce_fusion.cuh @@ -706,6 +706,7 @@ struct MoeFinalizeAllReduceFusionParams : public AllReduceFusionParams { // [num_tokens, top_k] int32_t* expanded_idx_to_permuted_idx = nullptr; // allreduce_in [maxPermutedPaddedCount, hidden_dim] + float routed_scaling_factor = 1.0f; }; template @@ -1297,9 +1298,9 @@ __global__ void moefinalize_allreduce_fusion_kernel_oneshot_lamport( int thread_offset_across_token = permuted_idx * params.hidden_dim + thread_offset_within_token; - float block_scale = 1.0; + float block_scale = params.routed_scaling_factor; if (use_scale_factor) { - block_scale = + block_scale *= static_cast(static_cast(params.expert_scale_factor)[expanded_idx]); } @@ -1474,6 +1475,9 @@ cudaError_t moefinalize_allreduce_fusion_op(MoeFinalizeAllReduceFusionParams "allreduce_in, expanded_idx_to_permuted_idx and top_k must be set"); FLASHINFER_CHECK(params.size % params.hidden_dim == 0, "size must be a multiple of hidden_dim"); FLASHINFER_CHECK(params.hidden_dim % VEC_SIZE == 0, "hidden_dim must be a multiple of VEC_SIZE"); + FLASHINFER_CHECK( + params.moe_allreduce_out || params.residual_out || params.norm_out || params.quant_out, + "at least one of moe_allreduce_out, residual_out, norm_out, quant_out must be set"); auto status = DISPATCH_MOEFINALIZEREDUCTION( params.nranks, params.residual_out, params.rms_gamma, params.quant_out, N_RANKS, RES, RMS, diff --git a/tests/comm/test_trtllm_moe_allreduce_fusion_finalize.py b/tests/comm/test_trtllm_moe_allreduce_fusion_finalize.py index 684fccdc50..9e8505e82b 100644 --- a/tests/comm/test_trtllm_moe_allreduce_fusion_finalize.py +++ b/tests/comm/test_trtllm_moe_allreduce_fusion_finalize.py @@ -1,5 +1,8 @@ import multiprocessing as mp +import os import socket +import sys +from pathlib import Path from typing import Any import numpy as np @@ -95,6 +98,14 @@ def rms_norm(x: torch.Tensor, weight: torch.Tensor = None, eps: float = 1e-6): norm_out = torch.empty_like(residual) residual_out = torch.empty_like(residual) + seq_len, hidden_size = norm_out.shape + quant_out = torch.empty( + seq_len * hidden_size // 4, dtype=dtype, device=device + ) + scale_out = torch.empty( + seq_len * hidden_size // SF_VEC_SIZE, dtype=dtype, device=device + ) + routed_scaling_factor = 2.5 # == Run kernel == torch.cuda.synchronize() @@ -117,6 +128,9 @@ def rms_norm(x: torch.Tensor, weight: torch.Tensor = None, eps: float = 1e-6): expert_scale_factor=scale, norm_out=norm_out, residual_out=residual_out, + quant_out=quant_out, + scale_out=scale_out, + routed_scaling_factor=routed_scaling_factor, ) torch.cuda.current_stream().wait_stream(s) @@ -138,6 +152,9 @@ def rms_norm(x: torch.Tensor, weight: torch.Tensor = None, eps: float = 1e-6): expert_scale_factor=scale, norm_out=norm_out, residual_out=residual_out, + quant_out=quant_out, + scale_out=scale_out, + routed_scaling_factor=routed_scaling_factor, ) # replay @@ -148,7 +165,8 @@ def rms_norm(x: torch.Tensor, weight: torch.Tensor = None, eps: float = 1e-6): # == Calculate reference output == expert_reduction = torch.sum( fc2_output_clone[expanded_idx_to_permuted_idx] - * scale.unsqueeze(-1), + * scale.unsqueeze(-1) + * routed_scaling_factor, dim=1, ) @@ -240,6 +258,26 @@ def rms_norm(x: torch.Tensor, weight: torch.Tensor = None, eps: float = 1e-6): dist.destroy_process_group(group=group) +def _require_cuda_usable(world_size: int) -> None: + """Skip if PyTorch cannot fully initialize CUDA (driver/toolkit skew, etc.).""" + if not torch.cuda.is_available(): + pytest.skip("CUDA is not available") + try: + torch.cuda.set_device(0) + torch.empty(1, device="cuda", dtype=torch.float32) + torch.cuda.synchronize() + except RuntimeError as exc: + pytest.skip( + "PyTorch cannot initialize CUDA (match torch CUDA build to your driver, " + f"e.g. cu128 wheel for a 12.8-capable driver). Original error: {exc}" + ) + available = torch.cuda.device_count() + if world_size > available: + pytest.skip( + f"world_size {world_size} is greater than available_gpus {available}" + ) + + def get_open_port() -> int: try: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: @@ -255,6 +293,14 @@ def multi_process_parallel( world_size: int, dtype: torch.dtype, test_target: Any, target_args: tuple = () ) -> None: mp.set_start_method("spawn", force=True) + # Spawn copies the parent's sys.path; workers unpickle targets from tests.comm.*. + repo_root = str(Path(__file__).resolve().parent.parent.parent) + if repo_root not in sys.path: + sys.path.insert(0, repo_root) + prev_pp = os.environ.get("PYTHONPATH", "") + os.environ["PYTHONPATH"] = ( + repo_root + (os.pathsep + prev_pp if prev_pp else "") + ) procs = [] distributed_init_port = get_open_port() @@ -276,12 +322,8 @@ def multi_process_parallel( def test_trtllm_moe_finalize_allreduce_fusion(world_size, dtype): np.random.seed(42) torch.manual_seed(42) + _require_cuda_usable(world_size) torch.cuda.manual_seed_all(42) - available_gpus = torch.cuda.device_count() - if world_size > available_gpus: - pytest.skip( - f"world_size {world_size} is greater than available_gpus {available_gpus}" - ) print(f"Running test for world_size={world_size}") # generate shared random input tensor across all ranks From 895c6e26277912864ce9f452ff5ebc9b8cd6f722 Mon Sep 17 00:00:00 2001 From: Matt Murphy Date: Fri, 3 Apr 2026 07:11:38 +0000 Subject: [PATCH 2/6] pre-commit --- flashinfer/comm/trtllm_ar.py | 3 ++- include/flashinfer/comm/trtllm_moe_allreduce_fusion.cuh | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/flashinfer/comm/trtllm_ar.py b/flashinfer/comm/trtllm_ar.py index 3034a185e7..253ff7e2d8 100644 --- a/flashinfer/comm/trtllm_ar.py +++ b/flashinfer/comm/trtllm_ar.py @@ -358,7 +358,8 @@ def trtllm_moe_finalize_allreduce_fusion( norm_out: Optional[torch.Tensor], residual_out: Optional[torch.Tensor], quant_out: Optional[torch.Tensor], - scale_out: Optional[torch.Tensor], launch_with_pdl: bool, + scale_out: Optional[torch.Tensor], + launch_with_pdl: bool, workspace: torch.Tensor, world_rank: int, world_size: int, diff --git a/include/flashinfer/comm/trtllm_moe_allreduce_fusion.cuh b/include/flashinfer/comm/trtllm_moe_allreduce_fusion.cuh index 45ec6a2470..6e2deaee9a 100644 --- a/include/flashinfer/comm/trtllm_moe_allreduce_fusion.cuh +++ b/include/flashinfer/comm/trtllm_moe_allreduce_fusion.cuh @@ -1476,8 +1476,8 @@ cudaError_t moefinalize_allreduce_fusion_op(MoeFinalizeAllReduceFusionParams FLASHINFER_CHECK(params.size % params.hidden_dim == 0, "size must be a multiple of hidden_dim"); FLASHINFER_CHECK(params.hidden_dim % VEC_SIZE == 0, "hidden_dim must be a multiple of VEC_SIZE"); FLASHINFER_CHECK( - params.moe_allreduce_out || params.residual_out || params.norm_out || params.quant_out, - "at least one of moe_allreduce_out, residual_out, norm_out, quant_out must be set"); + params.moe_allreduce_out || params.residual_out || params.norm_out || params.quant_out, + "at least one of moe_allreduce_out, residual_out, norm_out, quant_out must be set"); auto status = DISPATCH_MOEFINALIZEREDUCTION( params.nranks, params.residual_out, params.rms_gamma, params.quant_out, N_RANKS, RES, RMS, From b45d317ade44ad991c58d5da8ba674b2fb2dc93b Mon Sep 17 00:00:00 2001 From: Matt Murphy Date: Fri, 3 Apr 2026 07:11:38 +0000 Subject: [PATCH 3/6] pre-commit --- ...st_trtllm_moe_allreduce_fusion_finalize.py | 37 +++---------------- 1 file changed, 5 insertions(+), 32 deletions(-) diff --git a/tests/comm/test_trtllm_moe_allreduce_fusion_finalize.py b/tests/comm/test_trtllm_moe_allreduce_fusion_finalize.py index 9e8505e82b..54aa94fa66 100644 --- a/tests/comm/test_trtllm_moe_allreduce_fusion_finalize.py +++ b/tests/comm/test_trtllm_moe_allreduce_fusion_finalize.py @@ -1,8 +1,5 @@ import multiprocessing as mp -import os import socket -import sys -from pathlib import Path from typing import Any import numpy as np @@ -258,26 +255,6 @@ def rms_norm(x: torch.Tensor, weight: torch.Tensor = None, eps: float = 1e-6): dist.destroy_process_group(group=group) -def _require_cuda_usable(world_size: int) -> None: - """Skip if PyTorch cannot fully initialize CUDA (driver/toolkit skew, etc.).""" - if not torch.cuda.is_available(): - pytest.skip("CUDA is not available") - try: - torch.cuda.set_device(0) - torch.empty(1, device="cuda", dtype=torch.float32) - torch.cuda.synchronize() - except RuntimeError as exc: - pytest.skip( - "PyTorch cannot initialize CUDA (match torch CUDA build to your driver, " - f"e.g. cu128 wheel for a 12.8-capable driver). Original error: {exc}" - ) - available = torch.cuda.device_count() - if world_size > available: - pytest.skip( - f"world_size {world_size} is greater than available_gpus {available}" - ) - - def get_open_port() -> int: try: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: @@ -293,14 +270,6 @@ def multi_process_parallel( world_size: int, dtype: torch.dtype, test_target: Any, target_args: tuple = () ) -> None: mp.set_start_method("spawn", force=True) - # Spawn copies the parent's sys.path; workers unpickle targets from tests.comm.*. - repo_root = str(Path(__file__).resolve().parent.parent.parent) - if repo_root not in sys.path: - sys.path.insert(0, repo_root) - prev_pp = os.environ.get("PYTHONPATH", "") - os.environ["PYTHONPATH"] = ( - repo_root + (os.pathsep + prev_pp if prev_pp else "") - ) procs = [] distributed_init_port = get_open_port() @@ -322,8 +291,12 @@ def multi_process_parallel( def test_trtllm_moe_finalize_allreduce_fusion(world_size, dtype): np.random.seed(42) torch.manual_seed(42) - _require_cuda_usable(world_size) torch.cuda.manual_seed_all(42) + available_gpus = torch.cuda.device_count() + if world_size > available_gpus: + pytest.skip( + f"world_size {world_size} is greater than available_gpus {available_gpus}" + ) print(f"Running test for world_size={world_size}") # generate shared random input tensor across all ranks From 64522a06709fa490a1c675b6d9121c63f95bb5f7 Mon Sep 17 00:00:00 2001 From: Matt Murphy Date: Fri, 3 Apr 2026 07:49:17 +0000 Subject: [PATCH 4/6] formatting --- .../comm/trtllm_moe_allreduce_fusion.cuh | 27 ++++++++++++++----- 1 file changed, 20 insertions(+), 7 deletions(-) diff --git a/include/flashinfer/comm/trtllm_moe_allreduce_fusion.cuh b/include/flashinfer/comm/trtllm_moe_allreduce_fusion.cuh index 6e2deaee9a..f7e36917c4 100644 --- a/include/flashinfer/comm/trtllm_moe_allreduce_fusion.cuh +++ b/include/flashinfer/comm/trtllm_moe_allreduce_fusion.cuh @@ -1278,6 +1278,8 @@ __global__ void moefinalize_allreduce_fusion_kernel_oneshot_lamport( int top_k = params.top_k; bool use_scale_factor = params.expert_scale_factor != nullptr; + float routed_scaling_factor = params.routed_scaling_factor; + bool use_routed_scaling_factor = routed_scaling_factor != 1.0f; // Persistent Kernel // Each cluster iterate through all token it need to handle @@ -1298,9 +1300,9 @@ __global__ void moefinalize_allreduce_fusion_kernel_oneshot_lamport( int thread_offset_across_token = permuted_idx * params.hidden_dim + thread_offset_within_token; - float block_scale = params.routed_scaling_factor; + float block_scale = 1.0f; if (use_scale_factor) { - block_scale *= + block_scale = static_cast(static_cast(params.expert_scale_factor)[expanded_idx]); } @@ -1308,10 +1310,22 @@ __global__ void moefinalize_allreduce_fusion_kernel_oneshot_lamport( permuted_data.load(reinterpret_cast(params.allreduce_in) + thread_offset_across_token); // * acc += scale(data) + if (use_scale_factor) { +#pragma unroll + for (int i = 0; i < VEC_SIZE; ++i) { + // assume computation is done in ScaleType + accumulator[i] += static_cast(static_cast(permuted_data[i]) * block_scale); + } + } else { + accumulator = vec_add(accumulator, permuted_data); + } + } + + // Apply the global routed scaling once after accumulating all routed experts. + if (use_routed_scaling_factor) { #pragma unroll for (int i = 0; i < VEC_SIZE; ++i) { - // assume computation is done in ScaleType - accumulator[i] += static_cast(static_cast(permuted_data[i]) * block_scale); + accumulator[i] = static_cast(static_cast(accumulator[i]) * routed_scaling_factor); } } @@ -1475,9 +1489,8 @@ cudaError_t moefinalize_allreduce_fusion_op(MoeFinalizeAllReduceFusionParams "allreduce_in, expanded_idx_to_permuted_idx and top_k must be set"); FLASHINFER_CHECK(params.size % params.hidden_dim == 0, "size must be a multiple of hidden_dim"); FLASHINFER_CHECK(params.hidden_dim % VEC_SIZE == 0, "hidden_dim must be a multiple of VEC_SIZE"); - FLASHINFER_CHECK( - params.moe_allreduce_out || params.residual_out || params.norm_out || params.quant_out, - "at least one of moe_allreduce_out, residual_out, norm_out, quant_out must be set"); + FLASHINFER_CHECK(params.residual_out || params.norm_out || params.quant_out, + "at least one of residual_out, norm_out, quant_out must be set"); auto status = DISPATCH_MOEFINALIZEREDUCTION( params.nranks, params.residual_out, params.rms_gamma, params.quant_out, N_RANKS, RES, RMS, From 779a53ed7dcc66ddd6594f6eb362ee1f2bd69a76 Mon Sep 17 00:00:00 2001 From: Matt Murphy Date: Fri, 3 Apr 2026 08:13:55 +0000 Subject: [PATCH 5/6] . --- .../comm/trtllm_moe_allreduce_fusion.cuh | 25 ++++++++++++------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/include/flashinfer/comm/trtllm_moe_allreduce_fusion.cuh b/include/flashinfer/comm/trtllm_moe_allreduce_fusion.cuh index f7e36917c4..5a12fa737f 100644 --- a/include/flashinfer/comm/trtllm_moe_allreduce_fusion.cuh +++ b/include/flashinfer/comm/trtllm_moe_allreduce_fusion.cuh @@ -1300,17 +1300,14 @@ __global__ void moefinalize_allreduce_fusion_kernel_oneshot_lamport( int thread_offset_across_token = permuted_idx * params.hidden_dim + thread_offset_within_token; - float block_scale = 1.0f; - if (use_scale_factor) { - block_scale = - static_cast(static_cast(params.expert_scale_factor)[expanded_idx]); - } - vec_t permuted_data; permuted_data.load(reinterpret_cast(params.allreduce_in) + thread_offset_across_token); // * acc += scale(data) if (use_scale_factor) { + float block_scale = + routed_scaling_factor * + static_cast(static_cast(params.expert_scale_factor)[expanded_idx]); #pragma unroll for (int i = 0; i < VEC_SIZE; ++i) { // assume computation is done in ScaleType @@ -1321,8 +1318,9 @@ __global__ void moefinalize_allreduce_fusion_kernel_oneshot_lamport( } } - // Apply the global routed scaling once after accumulating all routed experts. - if (use_routed_scaling_factor) { + bool fuse_routed_scaling_with_shared_add = + !use_scale_factor && use_routed_scaling_factor && params.shared_expert_output; + if (!use_scale_factor && use_routed_scaling_factor && !fuse_routed_scaling_with_shared_add) { #pragma unroll for (int i = 0; i < VEC_SIZE; ++i) { accumulator[i] = static_cast(static_cast(accumulator[i]) * routed_scaling_factor); @@ -1336,8 +1334,17 @@ __global__ void moefinalize_allreduce_fusion_kernel_oneshot_lamport( vec_t shared_expert_output; shared_expert_output.load(reinterpret_cast(params.shared_expert_output) + thread_offset_across_token); + if (fuse_routed_scaling_with_shared_add) { #pragma unroll - accumulator = vec_add(accumulator, shared_expert_output); + for (int i = 0; i < VEC_SIZE; ++i) { + accumulator[i] = + static_cast(static_cast(accumulator[i]) * routed_scaling_factor + + static_cast(shared_expert_output[i])); + } + } else { +#pragma unroll + accumulator = vec_add(accumulator, shared_expert_output); + } } // * AR Store From a41fd42f3e86b7c53a399b9c338f3ebb97c04e6b Mon Sep 17 00:00:00 2001 From: Matt Murphy Date: Fri, 3 Apr 2026 08:31:58 +0000 Subject: [PATCH 6/6] . --- .../comm/trtllm_moe_allreduce_fusion.cuh | 17 +++-------------- 1 file changed, 3 insertions(+), 14 deletions(-) diff --git a/include/flashinfer/comm/trtllm_moe_allreduce_fusion.cuh b/include/flashinfer/comm/trtllm_moe_allreduce_fusion.cuh index 5a12fa737f..5d731bbbfc 100644 --- a/include/flashinfer/comm/trtllm_moe_allreduce_fusion.cuh +++ b/include/flashinfer/comm/trtllm_moe_allreduce_fusion.cuh @@ -1306,7 +1306,6 @@ __global__ void moefinalize_allreduce_fusion_kernel_oneshot_lamport( // * acc += scale(data) if (use_scale_factor) { float block_scale = - routed_scaling_factor * static_cast(static_cast(params.expert_scale_factor)[expanded_idx]); #pragma unroll for (int i = 0; i < VEC_SIZE; ++i) { @@ -1318,9 +1317,8 @@ __global__ void moefinalize_allreduce_fusion_kernel_oneshot_lamport( } } - bool fuse_routed_scaling_with_shared_add = - !use_scale_factor && use_routed_scaling_factor && params.shared_expert_output; - if (!use_scale_factor && use_routed_scaling_factor && !fuse_routed_scaling_with_shared_add) { + // Apply the global routed scaling once after accumulating all routed experts. + if (use_routed_scaling_factor) { #pragma unroll for (int i = 0; i < VEC_SIZE; ++i) { accumulator[i] = static_cast(static_cast(accumulator[i]) * routed_scaling_factor); @@ -1334,17 +1332,8 @@ __global__ void moefinalize_allreduce_fusion_kernel_oneshot_lamport( vec_t shared_expert_output; shared_expert_output.load(reinterpret_cast(params.shared_expert_output) + thread_offset_across_token); - if (fuse_routed_scaling_with_shared_add) { #pragma unroll - for (int i = 0; i < VEC_SIZE; ++i) { - accumulator[i] = - static_cast(static_cast(accumulator[i]) * routed_scaling_factor + - static_cast(shared_expert_output[i])); - } - } else { -#pragma unroll - accumulator = vec_add(accumulator, shared_expert_output); - } + accumulator = vec_add(accumulator, shared_expert_output); } // * AR Store