From 13bb23a30143901777b331fa79c4ee1f147a34e7 Mon Sep 17 00:00:00 2001 From: Dongxu Yang <78518666+dongxuy04@users.noreply.github.com> Date: Sat, 25 Oct 2025 16:54:42 +0800 Subject: [PATCH 01/13] enable W4A8MXFP4MXFP8TRTLLMGenFusedMoEMethod for EPLB Signed-off-by: Dongxu Yang <78518666+dongxuy04@users.noreply.github.com> --- .../_torch/modules/fused_moe/quantization.py | 20 +++++++++---------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/tensorrt_llm/_torch/modules/fused_moe/quantization.py b/tensorrt_llm/_torch/modules/fused_moe/quantization.py index 1e56f90d5e9..73ceababd5e 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/quantization.py +++ b/tensorrt_llm/_torch/modules/fused_moe/quantization.py @@ -2558,8 +2558,8 @@ def load_expert_w3_w1_weight(self, module: torch.nn.Module, w1_weight: torch.Tensor, w3_weight: torch.Tensor, dst_w3_w1_weight: torch.Tensor): - device = dst_w3_w1_weight.device - assert device.type == "cuda" + device = torch.device(f"cuda:{torch.cuda.current_device()}") + non_blocking = dst_w3_w1_weight.device.type == "cuda" # We pad before the sharding. This is done to avoid fractional scaling factors # per shard. @@ -2621,13 +2621,13 @@ def load_expert_w3_w1_weight(self, module: torch.nn.Module, # Copy the result into device buffer dst_w3_w1_weight.copy_(processed_w31_weight_shard.view( dst_w3_w1_weight.dtype), - non_blocking=True) + non_blocking=non_blocking) def load_expert_w2_weight(self, module: torch.nn.Module, w2_weight: torch.Tensor, dst_w2_weight: torch.Tensor): - device = dst_w2_weight.device - assert device.type == "cuda" + device = torch.device(f"cuda:{torch.cuda.current_device()}") + non_blocking = dst_w2_weight.device.type == "cuda" shard_w2_weight_dim = 2 * w2_weight.shape[1] if len( w2_weight.shape) == 2 else w2_weight.shape[0] @@ -2658,7 +2658,7 @@ def load_expert_w2_weight(self, module: torch.nn.Module, # Keep weights in device buffer dst_w2_weight.copy_(w2_weight_shard.view(dst_w2_weight.dtype), - non_blocking=True) + non_blocking=non_blocking) # Get permuted indices permute_indices = trtllmgen_maybe_get_cached_w2_permute_indices( dst_w2_weight, self._cache_permute_indices, epilogue_tile_m) @@ -2669,14 +2669,13 @@ def load_expert_w2_weight(self, module: torch.nn.Module, # Copy the result into device buffer dst_w2_weight.copy_(processed_w2_weight.view(dst_w2_weight.dtype), - non_blocking=True) + non_blocking=non_blocking) def load_expert_w3_w1_weight_scale_mxfp4( self, module: torch.nn.Module, w1_weight_scale: torch.Tensor, w3_weight_scale: torch.Tensor, dst_w3_w1_weight_scale: torch.Tensor): - device = dst_w3_w1_weight_scale.device - assert device.type == "cuda" + device = torch.device(f"cuda:{torch.cuda.current_device()}") alignment = _get_weight_alignment(self.weight_alignment, module.scaling_vector_size, @@ -2738,8 +2737,7 @@ def load_expert_w3_w1_weight_scale_mxfp4( def load_expert_w2_weight_scale_mxfp4(self, module: torch.nn.Module, w2_weight_scale: torch.Tensor, dst_w2_weight_scale: torch.Tensor): - device = dst_w2_weight_scale.device - assert device.type == "cuda" + device = torch.device(f"cuda:{torch.cuda.current_device()}") alignment = _get_weight_alignment(self.weight_alignment, module.scaling_vector_size, From e00300c0b8d44faf24cb20ca7c82dda2f0b16d54 Mon Sep 17 00:00:00 2001 From: Dongxu Yang <78518666+dongxuy04@users.noreply.github.com> Date: Mon, 27 Oct 2025 10:53:19 +0800 Subject: [PATCH 02/13] refactor eplb logic to base MoE class Signed-off-by: Dongxu Yang <78518666+dongxuy04@users.noreply.github.com> --- 
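Illustrative sketch (not part of the applied diff): the base-class _init_load_balancer() added below keeps the same default round-robin slot-to-expert assignment that WideEPMoE previously computed inline. The standalone example below reuses the formula from the diff with made-up sizes (num_experts=8, ep_size=4) to show what each EP rank ends up hosting when no load balancer is configured:

    # Default slot-to-expert assignment, same formula as _init_load_balancer()
    # in the diff below; sizes here are made up for illustration.
    num_experts = 8
    ep_size = 4
    expert_size_per_partition = num_experts // ep_size  # 2 local slots per rank

    initial_global_assignments = [
        (ep_rank * num_experts // ep_size + local_slot_id) % num_experts
        for ep_rank in range(ep_size)
        for local_slot_id in range(expert_size_per_partition)
    ]

    for ep_rank in range(ep_size):
        slot_start = ep_rank * expert_size_per_partition
        slot_end = slot_start + expert_size_per_partition
        print(f"rank {ep_rank}: local expert ids = "
              f"{initial_global_assignments[slot_start:slot_end]}")
    # rank 0: [0, 1], rank 1: [2, 3], rank 2: [4, 5], rank 3: [6, 7]
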
.../_torch/modules/fused_moe/create_moe.py | 5 +- .../modules/fused_moe/fused_moe_trtllm_gen.py | 28 +- .../modules/fused_moe/fused_moe_wide_ep.py | 166 ++---------- .../_torch/modules/fused_moe/interface.py | 253 +++++++++++++++++- 4 files changed, 294 insertions(+), 158 deletions(-) diff --git a/tensorrt_llm/_torch/modules/fused_moe/create_moe.py b/tensorrt_llm/_torch/modules/fused_moe/create_moe.py index d67000f5e27..2ff8a0121f0 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/create_moe.py +++ b/tensorrt_llm/_torch/modules/fused_moe/create_moe.py @@ -79,7 +79,9 @@ def create_moe( moe_load_balancer = get_moe_load_balancer() if moe_load_balancer is not None: - assert moe_cls == WideEPMoE, "MoE Load Balance is only supported in WideEPMoE now." + assert moe_cls in [ + WideEPMoE, TRTLLMGenFusedMoE + ], "MoE Load Balance is only supported in WideEPMoE and TRTLLMGenFusedMoE now." if bias: assert moe_cls in [CutlassFusedMoE, TritonFusedMoE, TRTLLMGenFusedMoE @@ -106,6 +108,7 @@ def create_moe( dtype=dtype, reduce_results=reduce_results, model_config=model_config, + aux_stream_dict=aux_stream_dict, weight_loading_mode=weight_loading_mode, bias=bias, layer_idx=layer_idx, diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py index 24bd9b0d909..c43a463bd40 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py @@ -14,7 +14,7 @@ fp4_block_scale_fake_output_without_finalize from ...distributed import allgather from ...model_config import ModelConfig -from ...utils import Fp4QuantizedTensor, ceil_div +from ...utils import AuxStreamType, Fp4QuantizedTensor, ceil_div from .interface import AlltoallMethodType, MoE, MoEWeightLoadingMode from .quantization import (DeepSeekFP8BlockScalesFusedMoEMethod, NVFP4TRTLLMGenFusedMoEMethod, @@ -37,6 +37,7 @@ class TRTLLMGenFusedMoE(MoE): dtype (Optional[torch.dtype]): Data type for the weights. reduce_results (bool): Whether to reduce the results across devices. model_config (ModelConfig): Configuration object for the model. + aux_stream_dict (Optional[Dict[AuxStreamType, torch.cuda.Stream]]): Auxiliary CUDA streams for overlapping. MoE torch custom op: Only support min-latency mode now (SM100 Blackwell only). @@ -66,6 +67,8 @@ def __init__( dtype: Optional[torch.dtype] = None, reduce_results: bool = False, model_config: ModelConfig = ModelConfig(), + aux_stream_dict: Optional[Dict[AuxStreamType, + torch.cuda.Stream]] = None, weight_loading_mode: MoEWeightLoadingMode = MoEWeightLoadingMode. VANILLA, layer_idx: Optional[int] = None, @@ -82,6 +85,7 @@ def __init__( dtype=dtype, reduce_results=reduce_results, model_config=model_config, + aux_stream_dict=aux_stream_dict, weight_loading_mode=weight_loading_mode, bias=bias, swiglu_alpha=swiglu_alpha, @@ -97,19 +101,11 @@ def __init__( assert not self.smart_router, "Smart router is not supported in TRTLLMGenFusedMoE." 
- self.num_slots = self.num_experts - self.expert_size_per_partition = self.num_experts // self.ep_size - self.initial_global_assignments = [ - (ep_rank * self.num_experts // self.ep_size + local_slot_id) % - self.num_experts for ep_rank in range(self.ep_size) - for local_slot_id in range(self.expert_size_per_partition) - ] - self.slot_start = self.ep_rank * self.expert_size_per_partition - self.slot_end = self.slot_start + self.expert_size_per_partition - self.initial_local_expert_ids = self.initial_global_assignments[ - self.slot_start:self.slot_end] - assert len( - self.initial_local_expert_ids) == self.expert_size_per_partition + # Note: Load balancer initialization is handled by base class _init_load_balancer() + # If no load balancer is available, the base class will set: + # - self.num_slots = self.num_experts + # - self.expert_size_per_partition = self.num_experts // self.ep_size + # - self.initial_global_assignments, self.slot_start, self.slot_end, etc. # TODO: AlltoAll code is largely duplicated with WideEPMoE. Consider refactor and reuse in the future. self.alltoall_method_type = self.select_alltoall_method_type() @@ -183,6 +179,10 @@ def select_alltoall_method_type(self) -> AlltoallMethodType: return AlltoallMethodType.MNNVL + def _supports_load_balancer(self) -> bool: + """TRTLLMGenFusedMoE supports load balancer.""" + return True + @cached_property def enable_alltoall(self): """ enable_alltoall (bool): whether to enable alltoall instead of allgather/reducescatter diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py index 6f31a44d5a6..3f9611d889d 100755 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py @@ -5,17 +5,17 @@ from tensorrt_llm._mnnvl_utils import MnnvlMemory, MnnvlMoe, MoEAlltoallInfo from tensorrt_llm._utils import is_sm_100f, local_mpi_size -from tensorrt_llm.functional import AllReduceStrategy from tensorrt_llm.logger import logger from tensorrt_llm.mapping import Mapping -from ...distributed import AllReduce, allgather, reducescatter +from ...distributed import allgather, reducescatter from ...expert_statistic import ExpertStatistic from ...model_config import ModelConfig from ...utils import AuxStreamType, EventType, Fp4QuantizedTensor from .deep_ep_utils import buffer_pool, deep_ep_installed from .interface import AlltoallMethodType, MoE from .moe_load_balancer import get_moe_load_balancer +from .interface import MoE from .ops import MoEOp, MoEOpSelector from .quantization import (DeepSeekFP8BlockScalesFusedMoEMethod, DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm, @@ -71,6 +71,7 @@ def __init__( dtype=dtype, reduce_results=reduce_results, model_config=model_config, + aux_stream_dict=aux_stream_dict, weight_loading_mode=weight_loading_mode, layer_idx=layer_idx, ) @@ -83,68 +84,8 @@ def __init__( # Store original hidden size before any potential padding self.unpadded_hidden_size = self.hidden_size - moe_load_balancer = get_moe_load_balancer() - self.layer_load_balancer = None - self.repeat_idx = 0 - self.repeat_count = 1 - self.use_cuda_graph = model_config.use_cuda_graph - moe_load_balancer_config = model_config.moe_load_balancer - init_expert_size_per_partition = moe_load_balancer_config.num_local_slots if moe_load_balancer_config else self.num_experts // self.ep_size - self.initial_global_assignments = [ - (ep_rank * self.num_experts // self.ep_size + local_slot_id) % - self.num_experts for ep_rank in 
range(self.ep_size) - for local_slot_id in range(init_expert_size_per_partition) - ] - - if moe_load_balancer: - assert moe_load_balancer_config is not None - top_k = self.routing_method.experts_per_token - self.expert_size_per_partition = moe_load_balancer_config.num_local_slots - self.layer_load_balancer = moe_load_balancer.add_layer( - self.num_experts, - top_k, - self.expert_size_per_partition, - aux_stream=None if aux_stream_dict is None else - aux_stream_dict[AuxStreamType.MoeBalancer]) - self.repeat_count = self.layer_load_balancer.get_repeat_count() - loaded_initial_global_assignments = moe_load_balancer_config.get_layer_initial_global_assignments( - self.layer_idx) - self.num_slots = moe_load_balancer_config.num_slots - if loaded_initial_global_assignments is not None: - assert isinstance(loaded_initial_global_assignments, list) - assert len(loaded_initial_global_assignments) == self.num_slots - assert self.num_slots >= self.num_experts - assert set(loaded_initial_global_assignments) == set( - range(self.num_experts)) - self.initial_global_assignments = loaded_initial_global_assignments - self.layer_load_balancer.set_initial_weight_assignments( - self.initial_global_assignments) - logger.info( - f"MoE load balancer enabled. num_experts = {num_experts}, num_slots = {self.num_slots}, ep_size = {self.ep_size}" - ) - logger.info( - f"initial_global_assignments (layer {self.layer_idx}) = {self.initial_global_assignments}" - ) - else: - assert num_experts % self.ep_size == 0 - self.expert_size_per_partition = num_experts // self.ep_size - self.num_slots = num_experts - if self.layer_load_balancer and not self.layer_load_balancer.is_static_routing( - ): - self.allreduce = AllReduce(mapping=model_config.mapping, - strategy=AllReduceStrategy.NCCL) - else: - self.allreduce = None - - self.slot_start = self.ep_rank * self.expert_size_per_partition - self.slot_end = self.slot_start + self.expert_size_per_partition - self.initial_local_expert_ids = self.initial_global_assignments[ - self.slot_start:self.slot_end] - assert len( - self.initial_local_expert_ids) == self.expert_size_per_partition - # The maximum number of tokens in MoE are multiplied by DP size when attention DP is enabled moe_max_num_tokens = model_config.max_num_tokens * model_config.mapping.dp_size self.moe_max_num_tokens = model_config.moe_max_num_tokens or moe_max_num_tokens @@ -460,8 +401,7 @@ def forward_chunk( is_first_call, is_last_call = repeating_info - if self.layer_load_balancer and is_first_call: - self.layer_load_balancer.start_wait_gpu_stage() + self._load_balancer_start_wait_gpu_stage(is_first_call) if not use_all_to_all or self.alltoall_method_type != AlltoallMethodType.MNNVL: alltoall_result_do_sum = True @@ -491,20 +431,12 @@ def forward_chunk( token_final_scales = None if self.layer_load_balancer: - if is_first_call: - self.layer_load_balancer.done_wait_gpu_stage() - if use_all_to_all and self.alltoall_method_type == AlltoallMethodType.MNNVL: - self.layer_load_balancer.update_local_statistic( - token_selected_experts, - is_first_stage=is_first_call, - is_last_stage=is_last_call) - else: - self.layer_load_balancer.update_statistic_with_local_ids( - token_selected_experts, - is_first_stage=is_first_call, - is_last_stage=is_last_call, - allreduce=self.allreduce) - token_selected_slots = self.layer_load_balancer.route( + self._load_balancer_done_wait_gpu_stage(is_first_call) + use_mnnvl = use_all_to_all and self.alltoall_method_type == AlltoallMethodType.MNNVL + 
self._load_balancer_update_statistic(token_selected_experts, + is_first_call, is_last_call, + use_mnnvl) + token_selected_slots = self._load_balancer_route( token_selected_experts, self.use_dp) else: token_selected_slots = token_selected_experts @@ -538,9 +470,8 @@ def forward_chunk( if self.enable_dummy_allreduce: self.dummy_allreduce() token_count = x.shape[0] - if is_last_call and self.layer_load_balancer is not None and not self.layer_load_balancer.is_static_routing( - ): - loadbalancer_local_statistic_info = self.layer_load_balancer.get_local_statistic_tensor( + if is_last_call: + loadbalancer_local_statistic_info = self._load_balancer_get_local_statistic_tensor( ) else: loadbalancer_local_statistic_info = None @@ -552,7 +483,7 @@ def forward_chunk( if gathered_loadbalancer_local_statistic_info is not None: gathered_loadbalancer_local_statistic_info = gathered_loadbalancer_local_statistic_info.view( (self.mapping.moe_ep_size, self.num_experts)) - self.layer_load_balancer.update_statistic_with_gathered_statistic( + self._load_balancer_update_statistic_with_gathered_statistic( gathered_loadbalancer_local_statistic_info) elif self.alltoall_method_type == AlltoallMethodType.DeepEP: if not use_postquant_alltoall: @@ -717,8 +648,7 @@ def forward_chunk( tuner_top_k=tuner_top_k, ) - if self.layer_load_balancer and is_last_call: - self.layer_load_balancer.start_set_cpu_stage() + self._load_balancer_start_set_cpu_stage(is_last_call) # Only in cutlass_min_latency_mode, the output is a list of tensors. # Otherwise, the output should be unpacked as a single tensor. @@ -761,8 +691,7 @@ def forward_chunk( f"Not available alltoall method type: {self.alltoall_method_type!r}" ) - if self.layer_load_balancer and is_last_call: - self.layer_load_balancer.done_set_cpu_stage() + self._load_balancer_done_set_cpu_stage(is_last_call) return final_hidden_states @@ -1003,62 +932,15 @@ def unpad_tensors(self, padded: bool, final_hidden_states = final_hidden_states[:0] return final_hidden_states - def register_parameter_weight_slot_fn(self, weight_name: str, - local_slot_id: int): - assert hasattr( - self, - weight_name), f"FusedMoE doesn't has weight attr: {weight_name}" - weight_tensor = getattr(self, weight_name).data[local_slot_id] - self.layer_load_balancer.register_weight_slot(local_slot_id, - weight_name, - weight_tensor) - - def register_to_fix_weight_fn(self, weight_name: str): - assert hasattr( - self, - weight_name), f"FusedMoE doesn't has weight attr: {weight_name}" - param = getattr(self, weight_name) - weight_tensor = param.detach() - assert isinstance( - weight_tensor, - torch.Tensor), f'weight {weight_name} should be a tensor' - assert weight_tensor.is_contiguous( - ), f'weight {weight_name} should be a is_contiguous, shape={weight_tensor.shape}, strides={weight_tensor.is_contiguous()}' - assert weight_tensor.numel() * weight_tensor.element_size() == weight_tensor.untyped_storage().size(),\ - f'weight {weight_name} shape={weight_tensor.shape} storage_size = {weight_tensor.untyped_storage().size()}, numel={weight_tensor.numel()}, eltsize={weight_tensor.element_size()}, dtype={weight_tensor.dtype}' - self.layer_load_balancer.make_tensor_host_accessible(weight_tensor) - param.data = weight_tensor - - def register_all_parameter_slot_and_to_fix_weight_fns( - self, weight_and_tensor_dict: Dict[str, torch.Tensor]): - """ - weight_and_tensor_dict: key is the name of the weight, value is the tensor of loaded shared tensor shard. - E.g. 
if num_experts=256 and 4 GPUs per node, then each rank need to load 256 / 4 = 64 expert weights for host sharing. - By this way, host_tensor_sharer can share the weights and each rank has access to all 256 experts. - """ - for local_slot_id, expert_id in enumerate( - self.initial_local_expert_ids): - for weight_name in weight_and_tensor_dict: - self.layer_load_balancer.add_register_weight_fn( - self.register_parameter_weight_slot_fn, - (weight_name, local_slot_id)) - for weight_name in weight_and_tensor_dict: - self.layer_load_balancer.add_to_migrate_weight_fn( - self.register_to_fix_weight_fn, (weight_name, )) - - local_shared_load_expert_ids = self.layer_load_balancer.get_load_expert_ids( - ) - for expert_id in range(self.num_experts): - for weight_name, weight_tensor in weight_and_tensor_dict.items(): - if expert_id in local_shared_load_expert_ids: - local_slot_id = local_shared_load_expert_ids.index( - expert_id) - self.layer_load_balancer.host_tensor_sharer.share_host_tensor_with_shape( - expert_id, weight_name, weight_tensor[local_slot_id]) - else: - self.layer_load_balancer.host_tensor_sharer.pre_register_host_tensor_with_shape( - expert_id, weight_name, weight_tensor.dtype, - weight_tensor[0].shape) + def _supports_load_balancer(self) -> bool: + """WideEPMoE supports load balancer.""" + return True + + def _get_load_balancer_aux_stream(self) -> Optional[torch.cuda.Stream]: + """Get auxiliary stream for load balancer from aux_stream_dict.""" + if self.aux_stream_dict is not None: + return self.aux_stream_dict.get(AuxStreamType.MoeBalancer) + return None def load_weights(self, weights: List[Dict]): assert self._weights_created diff --git a/tensorrt_llm/_torch/modules/fused_moe/interface.py b/tensorrt_llm/_torch/modules/fused_moe/interface.py index 54589a71e5a..114f0e231e8 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/interface.py +++ b/tensorrt_llm/_torch/modules/fused_moe/interface.py @@ -8,7 +8,7 @@ from ...distributed.ops import reducescatter from ...model_config import ModelConfig -from ...utils import (Fp4QuantizedTensor, get_model_extra_attrs, +from ...utils import (AuxStreamType, Fp4QuantizedTensor, get_model_extra_attrs, is_torch_compiling) from .routing import BaseMoeRoutingMethod @@ -122,6 +122,7 @@ class MoE(nn.Module): dtype (Optional[torch.dtype]): Data type for the weights. reduce_results (bool): Whether to reduce the results across devices. model_config (ModelConfig): Configuration object for the model. + aux_stream_dict (Optional[Dict[AuxStreamType, torch.cuda.Stream]]): Auxiliary CUDA streams for overlapping. """ def __init__( @@ -134,6 +135,8 @@ def __init__( dtype: Optional[torch.dtype] = None, reduce_results: bool = False, model_config: ModelConfig = ModelConfig(), + aux_stream_dict: Optional[Dict[AuxStreamType, + torch.cuda.Stream]] = None, weight_loading_mode: MoEWeightLoadingMode = MoEWeightLoadingMode. 
VANILLA, bias: bool = False, @@ -188,6 +191,254 @@ def __init__( strategy=model_config.allreduce_strategy, dtype=self.dtype) + # Initialize load balancer related attributes + self._init_load_balancer(model_config, aux_stream_dict) + + def _init_load_balancer( + self, + model_config: ModelConfig, + aux_stream_dict: Optional[Dict[AuxStreamType, + torch.cuda.Stream]] = None, + ): + """Initialize load balancer related attributes.""" + from .moe_load_balancer import get_moe_load_balancer + + # Store aux_stream_dict for load balancer + self.aux_stream_dict = aux_stream_dict + + # Initialize load balancer attributes + self.layer_load_balancer = None + self.repeat_idx = 0 + self.repeat_count = 1 + + # Get global load balancer instance + moe_load_balancer = get_moe_load_balancer() + moe_load_balancer_config = model_config.moe_load_balancer + + # Calculate initial expert assignments + init_expert_size_per_partition = ( + moe_load_balancer_config.num_local_slots + if moe_load_balancer_config else self.num_experts // self.ep_size) + + self.initial_global_assignments = [ + (ep_rank * self.num_experts // self.ep_size + local_slot_id) % + self.num_experts for ep_rank in range(self.ep_size) + for local_slot_id in range(init_expert_size_per_partition) + ] + + # Setup load balancer if available + if moe_load_balancer and self._supports_load_balancer(): + assert moe_load_balancer_config is not None + top_k = self.routing_method.experts_per_token + self.expert_size_per_partition = moe_load_balancer_config.num_local_slots + + # Add this layer to the load balancer + aux_stream = getattr(self, '_get_load_balancer_aux_stream', + lambda: None)() + self.layer_load_balancer = moe_load_balancer.add_layer( + self.num_experts, + top_k, + self.expert_size_per_partition, + aux_stream=aux_stream) + + self.repeat_count = self.layer_load_balancer.get_repeat_count() + + # Handle initial global assignments + loaded_initial_global_assignments = ( + moe_load_balancer_config.get_layer_initial_global_assignments( + self.layer_idx)) + self.num_slots = moe_load_balancer_config.num_slots + + if loaded_initial_global_assignments is not None: + assert isinstance(loaded_initial_global_assignments, list) + assert len(loaded_initial_global_assignments) == self.num_slots + assert self.num_slots >= self.num_experts + assert set(loaded_initial_global_assignments) == set( + range(self.num_experts)) + self.initial_global_assignments = loaded_initial_global_assignments + + self.layer_load_balancer.set_initial_weight_assignments( + self.initial_global_assignments) + + from tensorrt_llm.logger import logger + logger.info( + f"MoE load balancer enabled. 
num_experts = {self.num_experts}, " + f"num_slots = {self.num_slots}, ep_size = {self.ep_size}") + logger.info( + f"initial_global_assignments (layer {self.layer_idx}) = {self.initial_global_assignments}" + ) + else: + # Fallback when no load balancer + assert self.num_experts % self.ep_size == 0 + self.expert_size_per_partition = self.num_experts // self.ep_size + self.num_slots = self.num_experts + + # Calculate slot boundaries + self.slot_start = self.ep_rank * self.expert_size_per_partition + self.slot_end = self.slot_start + self.expert_size_per_partition + self.initial_local_expert_ids = self.initial_global_assignments[ + self.slot_start:self.slot_end] + assert len( + self.initial_local_expert_ids) == self.expert_size_per_partition + + # Setup AllReduce for dynamic routing if needed + if (self.layer_load_balancer + and not self.layer_load_balancer.is_static_routing()): + from tensorrt_llm.functional import AllReduceStrategy + + from ...distributed import AllReduce + self.allreduce = AllReduce(mapping=model_config.mapping, + strategy=AllReduceStrategy.NCCL) + else: + self.allreduce = None + + def _supports_load_balancer(self) -> bool: + """Check if this MoE implementation supports load balancer. + + Subclasses can override this to indicate load balancer support. + """ + return False + + def _get_load_balancer_aux_stream(self) -> Optional[torch.cuda.Stream]: + """Get auxiliary stream for load balancer. + + Subclasses can override this to provide the appropriate stream. + """ + return None + + def _load_balancer_start_wait_gpu_stage(self, is_first_call: bool): + """Start waiting for GPU stage in load balancer.""" + if self.layer_load_balancer and is_first_call: + self.layer_load_balancer.start_wait_gpu_stage() + + def _load_balancer_done_wait_gpu_stage(self, is_first_call: bool): + """Mark GPU wait stage as done in load balancer.""" + if self.layer_load_balancer and is_first_call: + self.layer_load_balancer.done_wait_gpu_stage() + + def _load_balancer_update_statistic(self, + token_selected_experts: torch.Tensor, + is_first_call: bool, + is_last_call: bool, + use_mnnvl: bool = False): + """Update load balancer statistics.""" + if self.layer_load_balancer: + if use_mnnvl: + self.layer_load_balancer.update_local_statistic( + token_selected_experts, + is_first_stage=is_first_call, + is_last_stage=is_last_call) + else: + self.layer_load_balancer.update_statistic_with_local_ids( + token_selected_experts, + is_first_stage=is_first_call, + is_last_stage=is_last_call, + allreduce=self.allreduce) + + def _load_balancer_route(self, token_selected_experts: torch.Tensor, + use_dp: bool) -> torch.Tensor: + """Route tokens using load balancer.""" + if self.layer_load_balancer: + return self.layer_load_balancer.route(token_selected_experts, + use_dp) + else: + return token_selected_experts + + def _load_balancer_start_set_cpu_stage(self, is_last_call: bool): + """Start CPU stage in load balancer.""" + if self.layer_load_balancer and is_last_call: + self.layer_load_balancer.start_set_cpu_stage() + + def _load_balancer_done_set_cpu_stage(self, is_last_call: bool): + """Mark CPU stage as done in load balancer.""" + if self.layer_load_balancer and is_last_call: + self.layer_load_balancer.done_set_cpu_stage() + + def _load_balancer_get_local_statistic_tensor(self): + """Get local statistic tensor from load balancer.""" + if (self.layer_load_balancer + and not self.layer_load_balancer.is_static_routing()): + return self.layer_load_balancer.get_local_statistic_tensor() + return None + + def 
_load_balancer_update_statistic_with_gathered_statistic( + self, gathered_statistic): + """Update load balancer with gathered statistics.""" + if self.layer_load_balancer: + self.layer_load_balancer.update_statistic_with_gathered_statistic( + gathered_statistic) + + def register_parameter_weight_slot_fn(self, weight_name: str, + local_slot_id: int): + """Register parameter weight slot function for load balancer.""" + if not self.layer_load_balancer: + return + + assert hasattr( + self, weight_name), f"MoE doesn't have weight attr: {weight_name}" + weight_tensor = getattr(self, weight_name).data[local_slot_id] + self.layer_load_balancer.register_weight_slot(local_slot_id, + weight_name, + weight_tensor) + + def register_to_fix_weight_fn(self, weight_name: str): + """Register weight fixing function for load balancer.""" + if not self.layer_load_balancer: + return + + assert hasattr( + self, weight_name), f"MoE doesn't have weight attr: {weight_name}" + param = getattr(self, weight_name) + weight_tensor = param.detach() + assert isinstance( + weight_tensor, + torch.Tensor), f'weight {weight_name} should be a tensor' + assert weight_tensor.is_contiguous(), ( + f'weight {weight_name} should be contiguous, ' + f'shape={weight_tensor.shape}, strides={weight_tensor.stride()}') + assert weight_tensor.numel() * weight_tensor.element_size( + ) == weight_tensor.untyped_storage().size(), ( + f'weight {weight_name} shape={weight_tensor.shape} ' + f'storage_size = {weight_tensor.untyped_storage().size()}, ' + f'numel={weight_tensor.numel()}, eltsize={weight_tensor.element_size()}, ' + f'dtype={weight_tensor.dtype}') + self.layer_load_balancer.make_tensor_host_accessible(weight_tensor) + param.data = weight_tensor + + def register_all_parameter_slot_and_to_fix_weight_fns( + self, weight_and_tensor_dict: Dict[str, torch.Tensor]): + """Register all parameter slot and weight fixing functions for load balancer.""" + if not self.layer_load_balancer: + return + + # Register weight functions for each local slot + for local_slot_id, expert_id in enumerate( + self.initial_local_expert_ids): + for weight_name in weight_and_tensor_dict: + self.layer_load_balancer.add_register_weight_fn( + self.register_parameter_weight_slot_fn, + (weight_name, local_slot_id)) + + # Register weight migration functions + for weight_name in weight_and_tensor_dict: + self.layer_load_balancer.add_to_migrate_weight_fn( + self.register_to_fix_weight_fn, (weight_name, )) + + # Setup host tensor sharing + local_shared_load_expert_ids = self.layer_load_balancer.get_load_expert_ids( + ) + for expert_id in range(self.num_experts): + for weight_name, weight_tensor in weight_and_tensor_dict.items(): + if expert_id in local_shared_load_expert_ids: + local_slot_id = local_shared_load_expert_ids.index( + expert_id) + self.layer_load_balancer.host_tensor_sharer.share_host_tensor_with_shape( + expert_id, weight_name, weight_tensor[local_slot_id]) + else: + self.layer_load_balancer.host_tensor_sharer.pre_register_host_tensor_with_shape( + expert_id, weight_name, weight_tensor.dtype, + weight_tensor[0].shape) + def _register_layer(self, model_config: ModelConfig): self.register_to_config = False if model_config is not None and self.layer_idx_str is not None: From 04774261db329d883d28afe2bf3f0216262dd7af Mon Sep 17 00:00:00 2001 From: Dongxu Yang <78518666+dongxuy04@users.noreply.github.com> Date: Mon, 27 Oct 2025 14:41:23 +0800 Subject: [PATCH 03/13] Add EPLB support for TRTLLMGEN backend(mxfp4 and nvfp4) and ut Signed-off-by: Dongxu Yang 
<78518666+dongxuy04@users.noreply.github.com> --- .../_torch/models/modeling_gpt_oss.py | 9 +- .../modules/fused_moe/fused_moe_trtllm_gen.py | 40 ++++ .../modules/fused_moe/fused_moe_wide_ep.py | 8 - .../_torch/modules/fused_moe/interface.py | 6 +- .../modules/fused_moe/moe_load_balancer.py | 1 + .../_torch/modules/fused_moe/quantization.py | 211 +++++++++++++----- .../defs/accuracy/test_llm_api_pytorch.py | 48 +++- 7 files changed, 243 insertions(+), 80 deletions(-) diff --git a/tensorrt_llm/_torch/models/modeling_gpt_oss.py b/tensorrt_llm/_torch/models/modeling_gpt_oss.py index 64c96a191e7..546d17d5985 100644 --- a/tensorrt_llm/_torch/models/modeling_gpt_oss.py +++ b/tensorrt_llm/_torch/models/modeling_gpt_oss.py @@ -140,6 +140,9 @@ def __init__( self.config = config # Store config as instance variable pretrained_config = config.pretrained_config self.num_experts = pretrained_config.num_local_experts + moe_load_balancer_config = config.moe_load_balancer + self.num_slots = moe_load_balancer_config.num_slots if moe_load_balancer_config and moe_load_balancer_config.num_slots else self.num_experts + self.layer_idx = layer_idx self.enable_attention_dp = config.mapping.enable_attention_dp self.mapping = config.mapping @@ -162,13 +165,13 @@ def __init__( if config.moe_backend.upper() == "TRTLLM" else torch.float32) self.swiglu_alpha = torch.tensor( - [1.702] * (self.num_experts // config.mapping.moe_ep_size), + [1.702] * (self.num_slots // config.mapping.moe_ep_size), dtype=torch.float32).cuda() self.swiglu_beta = torch.tensor( - [1.0] * (self.num_experts // config.mapping.moe_ep_size), + [1.0] * (self.num_slots // config.mapping.moe_ep_size), dtype=torch.float32).cuda() self.swiglu_limit = torch.tensor( - [7.0] * (self.num_experts // config.mapping.moe_ep_size), + [7.0] * (self.num_slots // config.mapping.moe_ep_size), dtype=torch.float32).cuda() # Prepare MoE creation parameters moe_params = { diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py index c43a463bd40..e8fcdd959d6 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py @@ -13,6 +13,7 @@ from ...custom_ops.trtllm_gen_custom_ops import \ fp4_block_scale_fake_output_without_finalize from ...distributed import allgather +from ...expert_statistic import ExpertStatistic from ...model_config import ModelConfig from ...utils import AuxStreamType, Fp4QuantizedTensor, ceil_div from .interface import AlltoallMethodType, MoE, MoEWeightLoadingMode @@ -348,6 +349,34 @@ def forward_impl( if token_final_scales is not None: token_final_scales = token_final_scales.to(torch.bfloat16) + # Apply load balancer routing if available + if self.layer_load_balancer: + # Determine if this is first/last call (TRTLLMGenFusedMoE doesn't use chunking) + is_first_call = self.repeat_idx == 0 + is_last_call = self.repeat_idx == self.repeat_count - 1 + + # Start GPU stage for first call + self._load_balancer_start_wait_gpu_stage(is_first_call) + self._load_balancer_done_wait_gpu_stage(is_first_call) + + # Update load balancer statistics (TRTLLMGenFusedMoE doesn't use MNNVL) + self._load_balancer_update_statistic(token_selected_experts, + is_first_call, + is_last_call, + use_mnnvl=False) + + # Route tokens to slots + token_selected_slots = self._load_balancer_route( + token_selected_experts, self.use_dp) + + # Update expert statistics + ExpertStatistic.set_layer(self.layer_idx) + 
ExpertStatistic.maybe_add_info(self.num_slots, + token_selected_slots) + + # Use routed slots for subsequent processing + token_selected_experts = token_selected_slots + x, x_sf, x_row, x_col = self._quantize_for_post_quant_comm(x) if self.enable_alltoall: @@ -763,10 +792,21 @@ def forward_impl( use_dp_padding=use_dp_padding, ) + # Handle load balancer CPU stage if needed + if self.layer_load_balancer: + is_last_call = self.repeat_idx == self.repeat_count - 1 + self._load_balancer_start_set_cpu_stage(is_last_call) + self._load_balancer_done_set_cpu_stage(is_last_call) + if use_dp_padding: rank = self.mapping.tp_rank final_hidden_states = final_hidden_states[: all_rank_num_tokens[rank]] + + # Update repeat index for load balancer + if self.layer_load_balancer: + self.repeat_idx = 0 if self.repeat_idx == self.repeat_count - 1 else self.repeat_idx + 1 + return final_hidden_states def forward_fake( diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py index 3f9611d889d..0496e4172ca 100755 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py @@ -14,8 +14,6 @@ from ...utils import AuxStreamType, EventType, Fp4QuantizedTensor from .deep_ep_utils import buffer_pool, deep_ep_installed from .interface import AlltoallMethodType, MoE -from .moe_load_balancer import get_moe_load_balancer -from .interface import MoE from .ops import MoEOp, MoEOpSelector from .quantization import (DeepSeekFP8BlockScalesFusedMoEMethod, DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm, @@ -936,12 +934,6 @@ def _supports_load_balancer(self) -> bool: """WideEPMoE supports load balancer.""" return True - def _get_load_balancer_aux_stream(self) -> Optional[torch.cuda.Stream]: - """Get auxiliary stream for load balancer from aux_stream_dict.""" - if self.aux_stream_dict is not None: - return self.aux_stream_dict.get(AuxStreamType.MoeBalancer) - return None - def load_weights(self, weights: List[Dict]): assert self._weights_created assert len(weights) == 1 diff --git a/tensorrt_llm/_torch/modules/fused_moe/interface.py b/tensorrt_llm/_torch/modules/fused_moe/interface.py index 114f0e231e8..b417574d656 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/interface.py +++ b/tensorrt_llm/_torch/modules/fused_moe/interface.py @@ -300,10 +300,12 @@ def _supports_load_balancer(self) -> bool: return False def _get_load_balancer_aux_stream(self) -> Optional[torch.cuda.Stream]: - """Get auxiliary stream for load balancer. + """Get auxiliary stream for load balancer from aux_stream_dict. - Subclasses can override this to provide the appropriate stream. + Returns the MoeBalancer stream from aux_stream_dict if available. 
""" + if self.aux_stream_dict is not None: + return self.aux_stream_dict.get(AuxStreamType.MoeBalancer) return None def _load_balancer_start_wait_gpu_stage(self, is_first_call: bool): diff --git a/tensorrt_llm/_torch/modules/fused_moe/moe_load_balancer.py b/tensorrt_llm/_torch/modules/fused_moe/moe_load_balancer.py index 530927d9ba7..6f0bfd64390 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/moe_load_balancer.py +++ b/tensorrt_llm/_torch/modules/fused_moe/moe_load_balancer.py @@ -957,6 +957,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): moe_model_arch_list = [ 'DeepseekV3ForCausalLM', + 'GptOssForCausalLM', 'MixtralForCausalLM', 'Llama4ForConditionalGeneration', 'Qwen2MoeForCausalLM', diff --git a/tensorrt_llm/_torch/modules/fused_moe/quantization.py b/tensorrt_llm/_torch/modules/fused_moe/quantization.py index 73ceababd5e..74cec95e865 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/quantization.py +++ b/tensorrt_llm/_torch/modules/fused_moe/quantization.py @@ -1626,25 +1626,38 @@ def load_expert_fc2_alpha_nvfp4(self, w2_weight_scale_2, dst_w2_alpha.copy_(1.0 / (final_fc2_input_scale * w2_weight_scale_2)) def load_all_fp4_weight_scales_and_alphas( - self, module: torch.nn.Module, weights: Dict, - load_expert_ids: List[int], dst_w3_w1_weight_scale: torch.Tensor, - dst_w2_weight_scale: torch.Tensor, dst_fc31_alpha: torch.Tensor, - dst_fc2_alpha: torch.Tensor): + self, + module: torch.nn.Module, + weights: Dict, + load_expert_ids: List[int], + dst_w3_w1_weight_scale: Optional[torch.Tensor], + dst_w2_weight_scale: Optional[torch.Tensor], + dst_fc31_alpha: torch.Tensor, + dst_fc2_alpha: torch.Tensor, + ignore_weight_scale=False): + w1_weight_scale = None + w2_weight_scale = None + w3_weight_scale = None + if not ignore_weight_scale: + assert dst_w3_w1_weight_scale is not None + assert dst_w2_weight_scale is not None for local_slot_id, expert_id in enumerate(load_expert_ids): if module.weight_loading_mode == MoEWeightLoadingMode.VANILLA: - w1_weight_scale = weights[f"{expert_id}.w1.weight_scale"] - w3_weight_scale = weights[f"{expert_id}.w3.weight_scale"] - w2_weight_scale = weights[f"{expert_id}.w2.weight_scale"] + if not ignore_weight_scale: + w1_weight_scale = weights[f"{expert_id}.w1.weight_scale"] + w3_weight_scale = weights[f"{expert_id}.w3.weight_scale"] + w2_weight_scale = weights[f"{expert_id}.w2.weight_scale"] w1_weight_scale_2 = weights[f"{expert_id}.w1.weight_scale_2"] w3_weight_scale_2 = weights[f"{expert_id}.w3.weight_scale_2"] w2_weight_scale_2 = weights[f"{expert_id}.w2.weight_scale_2"] elif module.weight_loading_mode == MoEWeightLoadingMode.FUSED_GATE_UP_PROJ: - w1_w3_weight_scale = weights["gate_up_proj_weight_scale"][ - expert_id].transpose(0, 1).contiguous() - w1_weight_scale, w3_weight_scale = w1_w3_weight_scale.chunk( - 2, dim=0) - w2_weight_scale = weights["down_proj_weight_scale"][ - expert_id].transpose(0, 1).contiguous() + if not ignore_weight_scale: + w1_w3_weight_scale = weights["gate_up_proj_weight_scale"][ + expert_id].transpose(0, 1).contiguous() + w1_weight_scale, w3_weight_scale = w1_w3_weight_scale.chunk( + 2, dim=0) + w2_weight_scale = weights["down_proj_weight_scale"][ + expert_id].transpose(0, 1).contiguous() w1_weight_scale_2 = weights["gate_up_proj_weight_scale_2"] w3_weight_scale_2 = weights["gate_up_proj_weight_scale_2"] w2_weight_scale_2 = weights["down_proj_weight_scale_2"] @@ -1663,11 +1676,12 @@ def load_all_fp4_weight_scales_and_alphas( w3_weight_scale_2) w3_weight_scale_2 = w1_weight_scale_2 - self.load_expert_w3_w1_weight_scale_nvfp4( - 
module, w1_weight_scale, w3_weight_scale, - dst_w3_w1_weight_scale[expert_idx]) - self.load_expert_w2_weight_scale_nvfp4( - module, w2_weight_scale, dst_w2_weight_scale[expert_idx]) + if not ignore_weight_scale: + self.load_expert_w3_w1_weight_scale_nvfp4( + module, w1_weight_scale, w3_weight_scale, + dst_w3_w1_weight_scale[expert_idx]) + self.load_expert_w2_weight_scale_nvfp4( + module, w2_weight_scale, dst_w2_weight_scale[expert_idx]) self.load_expert_fc31_alpha_nvfp4(w1_weight_scale_2, w3_weight_scale_2, @@ -1873,8 +1887,11 @@ def load_expert_w3_w1_weight(self, module: torch.nn.Module, w1_weight: torch.Tensor, w3_weight: torch.Tensor, dst_w3_w1_weight: torch.Tensor): - device = dst_w3_w1_weight.device - assert device.type == "cuda" + device = torch.device(f"cuda:{torch.cuda.current_device()}") + dst_on_gpu = dst_w3_w1_weight.device.type == "cuda" + dst_w3_w1_weight_gpu = dst_w3_w1_weight if dst_on_gpu else dst_w3_w1_weight.cuda( + ) + w1_weight_shard = load_weight_shard(w1_weight, module.tp_size, module.tp_rank, @@ -1890,7 +1907,7 @@ def load_expert_w3_w1_weight(self, module: torch.nn.Module, epilogue_tile_m = 128 # Keep weights in device buffer - dst_w3_weight, dst_w1_weight = dst_w3_w1_weight.split( + dst_w3_weight, dst_w1_weight = dst_w3_w1_weight_gpu.split( module.intermediate_size_per_partition, dim=0) dst_w3_weight.copy_(w3_weight_shard.view(dst_w3_weight.dtype)) @@ -1898,22 +1915,28 @@ def load_expert_w3_w1_weight(self, module: torch.nn.Module, # Get permute indices permute_indices = trtllmgen_maybe_get_cached_w3_w1_permute_indices( - dst_w3_w1_weight, self._cache_permute_indices, epilogue_tile_m) + dst_w3_w1_weight_gpu, self._cache_permute_indices, epilogue_tile_m) # Shuffle the weight according to permute indices processed_w31_weight_shard = torch.ops.trtllm.shuffle_matrix( - dst_w3_w1_weight, permute_indices.to(dst_w3_w1_weight.device)) + dst_w3_w1_weight_gpu, + permute_indices.to(dst_w3_w1_weight_gpu.device)) # Copy the result into device buffer - dst_w3_w1_weight.copy_(processed_w31_weight_shard.view( - dst_w3_w1_weight.dtype), - non_blocking=True) + dst_w3_w1_weight_gpu.copy_(processed_w31_weight_shard.view( + dst_w3_w1_weight_gpu.dtype), + non_blocking=dst_on_gpu) + if not dst_on_gpu: + dst_w3_w1_weight.copy_(dst_w3_w1_weight_gpu) def load_expert_w2_weight(self, module: torch.nn.Module, w2_weight: torch.Tensor, dst_w2_weight: torch.Tensor): - device = dst_w2_weight.device - assert device.type == "cuda" + device = torch.device(f"cuda:{torch.cuda.current_device()}") + dst_on_gpu = dst_w2_weight.device.type == "cuda" + dst_w2_weight_gpu = dst_w2_weight if dst_on_gpu else dst_w2_weight.cuda( + ) + w2_weight_shard = load_weight_shard(w2_weight, module.tp_size, module.tp_rank, @@ -1924,19 +1947,23 @@ def load_expert_w2_weight(self, module: torch.nn.Module, epilogue_tile_m = 128 # Keep weights in device buffer - dst_w2_weight.copy_(w2_weight_shard.view(dst_w2_weight.dtype), - non_blocking=True) + dst_w2_weight_gpu.copy_(w2_weight_shard.view(dst_w2_weight_gpu.dtype), + non_blocking=dst_on_gpu) # Get permuted indices permute_indices = trtllmgen_maybe_get_cached_w2_permute_indices( - dst_w2_weight, self._cache_permute_indices, epilogue_tile_m) + dst_w2_weight_gpu, self._cache_permute_indices, epilogue_tile_m) # Shuffle the weight according to permute indices processed_w2_weight = torch.ops.trtllm.shuffle_matrix( - dst_w2_weight, permute_indices.to(dst_w2_weight.device)) + dst_w2_weight_gpu, permute_indices.to(dst_w2_weight_gpu.device)) # Copy the result into device buffer - 
dst_w2_weight.copy_(processed_w2_weight.view(dst_w2_weight.dtype), - non_blocking=True) + dst_w2_weight_gpu.copy_(processed_w2_weight.view( + dst_w2_weight_gpu.dtype), + non_blocking=dst_on_gpu) + + if not dst_on_gpu: + dst_w2_weight.copy_(dst_w2_weight_gpu) def load_expert_w3_w1_weight_scale_nvfp4( self, @@ -1945,8 +1972,11 @@ def load_expert_w3_w1_weight_scale_nvfp4( w3_weight_scale: torch.Tensor, dst_w3_w1_weight_scale: torch.Tensor, num_elts_per_sf: int = 16): - device = dst_w3_w1_weight_scale.device - assert device.type == "cuda" + device = torch.device(f"cuda:{torch.cuda.current_device()}") + dst_on_gpu = dst_w3_w1_weight_scale.device.type == "cuda" + dst_w3_w1_weight_scale_gpu = dst_w3_w1_weight_scale if dst_on_gpu else dst_w3_w1_weight_scale.cuda( + ) + w1_weight_scale = load_weight_shard(w1_weight_scale, module.tp_size, module.tp_rank, @@ -1959,34 +1989,34 @@ def load_expert_w3_w1_weight_scale_nvfp4( device=device) # Keep weights in device buffer # w3 - dst_w3_weight_scale = dst_w3_w1_weight_scale.narrow( + dst_w3_weight_scale = dst_w3_w1_weight_scale_gpu.narrow( dim=0, start=0, length=module.intermediate_size_per_partition) dst_w3_weight_scale.copy_( w3_weight_scale.view(dst_w3_weight_scale.dtype)) # w1 - dst_w1_weight_scale = dst_w3_w1_weight_scale.narrow( + dst_w1_weight_scale = dst_w3_w1_weight_scale_gpu.narrow( dim=0, start=module.intermediate_size_per_partition, length=module.intermediate_size_per_partition) dst_w1_weight_scale.copy_( w1_weight_scale.view(dst_w1_weight_scale.dtype)) - orig_shape = dst_w3_w1_weight_scale.shape + orig_shape = dst_w3_w1_weight_scale_gpu.shape # trtllm-gen specific block scales preprocessing logics epilogue_tile_m = 128 # FIXME # Get permute indices permute_indices = trtllmgen_maybe_get_cached_w3_w1_permute_indices( - dst_w3_w1_weight_scale.view(float4_sf_dtype), + dst_w3_w1_weight_scale_gpu.view(float4_sf_dtype), self._cache_permute_indices, epilogue_tile_m, num_elts_per_sf=num_elts_per_sf) # Shuffle the weight according to permute indices w3_w1_weight_scale = torch.ops.trtllm.shuffle_matrix( - dst_w3_w1_weight_scale.view(float4_sf_dtype), permute_indices) + dst_w3_w1_weight_scale_gpu.view(float4_sf_dtype), permute_indices) # Assert should only be removed during debugging assert w3_w1_weight_scale.is_cuda, "w3_w1_weight_scale.is_cuda should be true or suffer from slow speed" @@ -1994,52 +2024,62 @@ def load_expert_w3_w1_weight_scale_nvfp4( processed_w3_w1_weight_scale = torch.ops.trtllm.block_scale_interleave( w3_w1_weight_scale.view(float4_sf_dtype).reshape(orig_shape)) # Copy the result into device buffer - dst_w3_w1_weight_scale.copy_( + dst_w3_w1_weight_scale_gpu.copy_( processed_w3_w1_weight_scale.view( self.block_scales_dtype).reshape(orig_shape)) + if not dst_on_gpu: + dst_w3_w1_weight_scale.copy_(dst_w3_w1_weight_scale_gpu) + def load_expert_w2_weight_scale_nvfp4(self, module: torch.nn.Module, w2_weight_scale: torch.Tensor, dst_w2_weight_scale: torch.Tensor, num_elts_per_sf: int = 16): - device = dst_w2_weight_scale.device - assert device.type == "cuda" + device = torch.device(f"cuda:{torch.cuda.current_device()}") + dst_on_gpu = dst_w2_weight_scale.device.type == "cuda" + dst_w2_weight_scale_gpu = dst_w2_weight_scale if dst_on_gpu else dst_w2_weight_scale.cuda( + ) + w2_weight_scale = load_weight_shard(w2_weight_scale, module.tp_size, module.tp_rank, TensorParallelMode.ROW, device=device) # Keep weights in device buffer - dst_w2_weight_scale.copy_( - w2_weight_scale.view(dst_w2_weight_scale.dtype)) + dst_w2_weight_scale_gpu.copy_( + 
w2_weight_scale.view(dst_w2_weight_scale_gpu.dtype)) - orig_shape = dst_w2_weight_scale.shape + orig_shape = dst_w2_weight_scale_gpu.shape # trtllm-gen specific block scales preprocessing logics epilogue_tile_m = 128 # FIXME: read from kernel # Assert should only be removed during debugging - assert dst_w2_weight_scale.is_cuda, "dst_w2_weight_scale.is_cuda should be true or suffer from slow speed" + assert dst_w2_weight_scale_gpu.is_cuda, "dst_w2_weight_scale.is_cuda should be true or suffer from slow speed" # Get permute indices permute_indices = trtllmgen_maybe_get_cached_w2_permute_indices( - dst_w2_weight_scale.view(float4_sf_dtype), + dst_w2_weight_scale_gpu.view(float4_sf_dtype), self._cache_permute_indices, epilogue_tile_m, num_elts_per_sf=num_elts_per_sf) # Shuffle the weight according to permute indices w_shuffled = torch.ops.trtllm.shuffle_matrix( - dst_w2_weight_scale.view(dtype=float4_sf_dtype), permute_indices) + dst_w2_weight_scale_gpu.view(dtype=float4_sf_dtype), + permute_indices) # Interleave the weight. processed_w2_weight_scale = torch.ops.trtllm.block_scale_interleave( w_shuffled) # Copy the result into device buffer - dst_w2_weight_scale.copy_( + dst_w2_weight_scale_gpu.copy_( processed_w2_weight_scale.view( self.block_scales_dtype).reshape(orig_shape)) + if not dst_on_gpu: + dst_w2_weight_scale.copy_(dst_w2_weight_scale_gpu) + def load_quant_scales(self, module: torch.nn.Module, weights: Dict): super().load_quant_scales(module, weights) @@ -2048,6 +2088,40 @@ def load_quant_scales(self, module: torch.nn.Module, weights: Dict): module.fc31_alpha.data, non_blocking=True) + if self.need_load_shared_weights(module): + local_shared_load_expert_ids = module.layer_load_balancer.get_load_expert_ids( + ) + # fc2_input_scale has shape (1,), so we don't need to reload that. + # We only need to load fc31_alpha, + # we reuse existing load_all_fp4_weight_scales_and_alphas logic, ignore large weight_scales. 
+ local_shared_fc31_alpha_tensors = torch.empty( + (len(local_shared_load_expert_ids), ) + + module.fc31_alpha.data.shape[1:], + dtype=module.fc31_alpha.data.dtype, + device='cpu') + local_shared_fc2_alpha_tensors = torch.empty( + (len(local_shared_load_expert_ids), ) + + module.fc2_alpha.data.shape[1:], + dtype=module.fc2_alpha.data.dtype, + device='cpu') + self.load_all_fp4_weight_scales_and_alphas( + module, + weights, + local_shared_load_expert_ids, + None, + None, + local_shared_fc31_alpha_tensors, + local_shared_fc2_alpha_tensors, + ignore_weight_scale=True) + + local_shared_fc31_scale_c = module.fc2_input_scale.data.cpu( + ) * local_shared_fc31_alpha_tensors + + module.register_all_parameter_slot_and_to_fix_weight_fns({ + 'fc31_scale_c': + local_shared_fc31_scale_c, + }) + class W4A8NVFP4FP8TRTLLMGenFusedMoEMethod(NVFP4TRTLLMGenFusedMoEMethod): @@ -2676,6 +2750,9 @@ def load_expert_w3_w1_weight_scale_mxfp4( w3_weight_scale: torch.Tensor, dst_w3_w1_weight_scale: torch.Tensor): device = torch.device(f"cuda:{torch.cuda.current_device()}") + dst_on_gpu = dst_w3_w1_weight_scale.device.type == "cuda" + dst_w3_w1_weight_scale_gpu = dst_w3_w1_weight_scale if dst_on_gpu else dst_w3_w1_weight_scale.cuda( + ) alignment = _get_weight_alignment(self.weight_alignment, module.scaling_vector_size, @@ -2701,28 +2778,28 @@ def load_expert_w3_w1_weight_scale_mxfp4( device=device) # Keep weights in device buffer - dst_w3_weight_scale, dst_w1_weight_scale = dst_w3_w1_weight_scale.chunk( + dst_w3_weight_scale, dst_w1_weight_scale = dst_w3_w1_weight_scale_gpu.chunk( 2, dim=0) dst_w3_weight_scale.copy_( w3_weight_scale.view(dst_w3_weight_scale.dtype)) dst_w1_weight_scale.copy_( w1_weight_scale.view(dst_w1_weight_scale.dtype)) - orig_shape = dst_w3_w1_weight_scale.shape + orig_shape = dst_w3_w1_weight_scale_gpu.shape # trtllm-gen specific block scales preprocessing logics epilogue_tile_m = 128 # FIXME # Get permute indices permute_indices = trtllmgen_maybe_get_cached_w3_w1_permute_indices( - dst_w3_w1_weight_scale.view(float4_sf_dtype), + dst_w3_w1_weight_scale_gpu.view(float4_sf_dtype), self._cache_permute_indices, epilogue_tile_m, num_elts_per_sf=32) # Shuffle the weight according to permute indices w3_w1_weight_scale = torch.ops.trtllm.shuffle_matrix( - dst_w3_w1_weight_scale.view(float4_sf_dtype), permute_indices) + dst_w3_w1_weight_scale_gpu.view(float4_sf_dtype), permute_indices) # Assert should only be removed during debugging assert w3_w1_weight_scale.is_cuda, "w3_w1_weight_scale.is_cuda should be true or suffer from slow speed" @@ -2730,14 +2807,20 @@ def load_expert_w3_w1_weight_scale_mxfp4( processed_w3_w1_weight_scale = torch.ops.trtllm.block_scale_interleave( w3_w1_weight_scale.view(float4_sf_dtype).reshape(orig_shape)) # Copy the result into device buffer - dst_w3_w1_weight_scale.copy_( + dst_w3_w1_weight_scale_gpu.copy_( processed_w3_w1_weight_scale.view( self.block_scales_dtype).reshape(orig_shape)) + if not dst_on_gpu: + dst_w3_w1_weight_scale.copy_(dst_w3_w1_weight_scale_gpu) + def load_expert_w2_weight_scale_mxfp4(self, module: torch.nn.Module, w2_weight_scale: torch.Tensor, dst_w2_weight_scale: torch.Tensor): device = torch.device(f"cuda:{torch.cuda.current_device()}") + dst_on_gpu = dst_w2_weight_scale.device.type == "cuda" + dst_w2_weight_scale_gpu = dst_w2_weight_scale if dst_on_gpu else dst_w2_weight_scale.cuda( + ) alignment = _get_weight_alignment(self.weight_alignment, module.scaling_vector_size, @@ -2755,35 +2838,39 @@ def load_expert_w2_weight_scale_mxfp4(self, module: 
torch.nn.Module, device=device) # Keep weights in device buffer - dst_w2_weight_scale.copy_( + dst_w2_weight_scale_gpu.copy_( w2_weight_scale.view(dst_w2_weight_scale.dtype)) - orig_shape = dst_w2_weight_scale.shape + orig_shape = dst_w2_weight_scale_gpu.shape # trtllm-gen specific block scales preprocessing logics epilogue_tile_m = 128 # FIXME: read from kernel # Assert should only be removed during debugging - assert dst_w2_weight_scale.is_cuda, "dst_w2_weight_scale.is_cuda should be true or suffer from slow speed" + assert dst_w2_weight_scale_gpu.is_cuda, "dst_w2_weight_scale.is_cuda should be true or suffer from slow speed" # Get permute indices permute_indices = trtllmgen_maybe_get_cached_w2_permute_indices( - dst_w2_weight_scale.view(float4_sf_dtype), + dst_w2_weight_scale_gpu.view(float4_sf_dtype), self._cache_permute_indices, epilogue_tile_m, num_elts_per_sf=32) # Shuffle the weight according to permute indices w_shuffled = torch.ops.trtllm.shuffle_matrix( - dst_w2_weight_scale.view(dtype=float4_sf_dtype), permute_indices) + dst_w2_weight_scale_gpu.view(dtype=float4_sf_dtype), + permute_indices) # Interleave the weight. processed_w2_weight_scale = torch.ops.trtllm.block_scale_interleave( w_shuffled) # Copy the result into device buffer - dst_w2_weight_scale.copy_( + dst_w2_weight_scale_gpu.copy_( processed_w2_weight_scale.view( self.block_scales_dtype).reshape(orig_shape)) + if not dst_on_gpu: + dst_w2_weight_scale.copy_(dst_w2_weight_scale_gpu) + class W4A16MXFP4TRTLLMGenFusedMoEMethod(MXFP4WeightTRTLLMGenFusedMoEMethod): pass diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py index fe95c4cc093..c07e6e99527 100644 --- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py +++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py @@ -1617,14 +1617,15 @@ def test_fp8_block_scales_4gpus_static_eplb(self): @pytest.mark.skip_less_device(4) @pytest.mark.skip_device_not_contain(["GB200"]) + @parametrize_with_ids("moe_backend", ["WIDEEP"]) @parametrize_with_ids("mtp_nextn", [0, 2]) - def test_bfloat16_4gpus_online_eplb(self, mtp_nextn): + def test_bfloat16_4gpus_online_eplb(self, moe_backend, mtp_nextn): kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.7) num_slots = 80 eplb_config = MoeLoadBalancerConfig(num_slots=num_slots, layer_updates_per_iter=2) pytorch_config = dict(cuda_graph_config=CudaGraphConfig(), - moe_config=MoeConfig(backend="WIDEEP", + moe_config=MoeConfig(backend=moe_backend, load_balancer=eplb_config)) mtp_config = None if mtp_nextn > 0: @@ -1641,14 +1642,15 @@ def test_bfloat16_4gpus_online_eplb(self, mtp_nextn): @pytest.mark.skip_less_device(4) @pytest.mark.skip_device_not_contain(["GB200"]) + @parametrize_with_ids("moe_backend", ["WIDEEP", "TRTLLM"]) @parametrize_with_ids("fp8kv", [True, False]) - def test_nvfp4_4gpus_online_eplb(self, fp8kv): + def test_nvfp4_4gpus_online_eplb(self, moe_backend, fp8kv): kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.7) num_slots = 80 eplb_config = MoeLoadBalancerConfig(num_slots=num_slots, layer_updates_per_iter=2) pytorch_config = dict(cuda_graph_config=CudaGraphConfig(), - moe_config=MoeConfig(backend="WIDEEP", + moe_config=MoeConfig(backend=moe_backend, load_balancer=eplb_config)) if fp8kv: kv_cache_config.dtype = "fp8" @@ -3915,7 +3917,7 @@ def test_eagle3(self, moe_backend, mocker): chat_template_kwargs = dict(reasoning_effort="medium") extra_evaluator_kwargs = { **self.extra_evaluator_kwargs, "chat_template_kwargs": - 
chat_template_kwargs + chat_template_kwargs } sampling_params = SamplingParams( @@ -3929,6 +3931,42 @@ def test_eagle3(self, moe_backend, mocker): extra_evaluator_kwargs=extra_evaluator_kwargs) + @pytest.mark.skip_less_device(4) + @pytest.mark.skip_device_not_contain(["GB200"]) + @pytest.mark.parametrize( + "kv_cache_dtype", + ["auto", pytest.param("fp8", marks=skip_pre_blackwell)]) + def test_w4_4gpus_online_eplb(self, kv_cache_dtype, mocker): + """Test GPTOSS with online expert parallel load balancer using TRTLLM backend and attention DP.""" + mocker.patch.object(GSM8K, "MAX_OUTPUT_LEN", 8192) + mocker.patch.dict(GSM8K.EVALUATE_KWARGS, + {"scores_filter": "exact_match,flexible-extract"}) + + kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.5, + dtype=kv_cache_dtype) + + # Configure online expert parallel load balancer + num_slots = 144 + eplb_config = MoeLoadBalancerConfig(num_slots=num_slots, + layer_updates_per_iter=2) + + pytorch_config = dict(disable_overlap_scheduler=False, + cuda_graph_config=CudaGraphConfig(), + moe_config=MoeConfig(backend="TRTLLM", + load_balancer=eplb_config)) + + with LLM(self.MODEL_PATH, + tensor_parallel_size=4, + pipeline_parallel_size=1, + moe_expert_parallel_size=4, + kv_cache_config=kv_cache_config, + enable_attention_dp=True, + **pytorch_config) as llm: + model_name = "GPT-OSS/120B-MXFP4" + task = GSM8K(model_name) + task.evaluate(llm, + extra_evaluator_kwargs=self.extra_evaluator_kwargs) + @skip_pre_hopper class TestEXAONE4(LlmapiAccuracyTestHarness): MODEL_NAME = "LGAI-EXAONE/EXAONE-4.0-32B" From dfa7fa3c5c1e44379ddd5d1e67115d9dd2e654d2 Mon Sep 17 00:00:00 2001 From: Dongxu Yang <78518666+dongxuy04@users.noreply.github.com> Date: Mon, 3 Nov 2025 11:30:38 +0800 Subject: [PATCH 04/13] optimize memory usage by unmap some weights that might not used again Signed-off-by: Dongxu Yang <78518666+dongxuy04@users.noreply.github.com> --- .../_torch/modules/fused_moe/interface.py | 6 + .../modules/fused_moe/moe_load_balancer.py | 108 ++++++++++++++++++ .../_torch/modules/fused_moe/quantization.py | 6 + 3 files changed, 120 insertions(+) diff --git a/tensorrt_llm/_torch/modules/fused_moe/interface.py b/tensorrt_llm/_torch/modules/fused_moe/interface.py index b417574d656..0974738cf76 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/interface.py +++ b/tensorrt_llm/_torch/modules/fused_moe/interface.py @@ -292,6 +292,12 @@ def _init_load_balancer( else: self.allreduce = None + def _add_raw_shared_weights_for_unmap(self, + weight_tensors: List[torch.Tensor]): + if self.layer_load_balancer: + self.layer_load_balancer._add_raw_host_weight_for_unmap( + weight_tensors) + def _supports_load_balancer(self) -> bool: """Check if this MoE implementation supports load balancer. diff --git a/tensorrt_llm/_torch/modules/fused_moe/moe_load_balancer.py b/tensorrt_llm/_torch/modules/fused_moe/moe_load_balancer.py index 6f0bfd64390..20fe1d9a8be 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/moe_load_balancer.py +++ b/tensorrt_llm/_torch/modules/fused_moe/moe_load_balancer.py @@ -1,3 +1,6 @@ +import ctypes +import gc +import mmap import os import threading from contextlib import nullcontext @@ -18,6 +21,81 @@ from ..multi_stream_utils import do_multi_stream +def advise_tensor_pageout(tensor, mode: str = "dontneed"): + """ + Advise the OS to page out or discard the physical memory pages backing a CPU tensor. + This works only for tensors backed by an mmap'ed file or shared memory. 
+ + Parameters + ---------- + tensor : torch.Tensor + A CPU tensor (usually created via torch.from_file() or numpy.memmap()). + mode : str, optional + "pageout" -> use MADV_PAGEOUT (asynchronous pageout, Linux 4.5+) + "dontneed" -> use MADV_DONTNEED (immediate discard) + + Raises + ------ + ValueError + If the tensor is not on CPU or an invalid mode is given. + OSError + If the madvise() syscall fails (errno will be included). + + Notes + ----- + - Works only on Linux systems. + - This call only gives a *hint* to the kernel: the OS may decide to ignore it. + - Safe to call on mmap-backed tensors (data will be reloaded on next access). + - If called on a malloc-based tensor (not mmap), madvise() simply does nothing + and returns 0 (success) because the virtual address range is anonymous memory. + It does NOT crash or corrupt data. + """ + + libc = ctypes.CDLL("libc.so.6", use_errno=True) + MADV_PAGEOUT = 21 + MADV_DONTNEED = 4 + + if not tensor.device.type == "cpu": + raise ValueError("Only CPU tensors are supported.") + + # Get raw pointer and size in bytes + ptr = tensor.data_ptr() + nbytes = tensor.numel() * tensor.element_size() + + # Only operate on complete pages within the tensor's memory range + # to avoid affecting memory outside the tensor boundaries + page_size = mmap.PAGESIZE + + # Round up to the first complete page boundary inside the tensor + start_aligned = (ptr + page_size - 1) & ~(page_size - 1) + + # Round down to the last complete page boundary inside the tensor + end_ptr = ptr + nbytes + end_aligned = end_ptr & ~(page_size - 1) + + # Calculate the size of complete pages within the tensor + size = end_aligned - start_aligned + + # If there are no complete pages within the tensor, skip madvise + if size <= 0: + return + + # Choose advice mode + if mode == "pageout": + advice = MADV_PAGEOUT + elif mode == "dontneed": + advice = MADV_DONTNEED + else: + raise ValueError("mode must be 'pageout' or 'dontneed'.") + + # Perform madvise() only on complete pages within the tensor + ret = libc.madvise(ctypes.c_void_p(start_aligned), ctypes.c_size_t(size), + ctypes.c_int(advice)) + if ret != 0: + err = ctypes.get_errno() + raise OSError(err, f"madvise() failed with errno={err}") + + def _tensor_to_weight(t: torch.Tensor) -> _tbr.MoeWeight: """ Convert a tensor to a MoeWeight object. @@ -80,6 +158,8 @@ def __init__(self, layer_id: int, expert_count: int, self.shared_tensors = {} self.names = [] + self.loaded_shared_weights = [] + def set_shared_memory_base_name(self, shared_memory_base_name): """ Set the shared memory base name for the layer. @@ -160,6 +240,16 @@ def share_host_tensor_with_shape(self, expert_id: int, name: str, def align_size(size: int): return (size + 256 - 1) // 256 * 256 + def add_raw_host_weight_for_unmap(self, + raw_weight_tensors: List[torch.Tensor]): + """ + Add a raw weight (mmapped Tensor) to HostMoeTensorSharer for later madvise to save host memory. + + Args: + raw_weight_tensors: A list of raw weight tensors + """ + self.loaded_shared_weights.extend(raw_weight_tensors) + def finalize_layer_weights(self): self.names = list(sorted(self.name_info.keys())) assert len( @@ -215,6 +305,13 @@ def finalize_layer_weights(self): offset += aligned_size self.shared_tensors = {} + for raw_weight in self.loaded_shared_weights: + advise_tensor_pageout(raw_weight) + + self.loaded_shared_weights = [] + + gc.collect() + def finalize_host_tensor_sharing(self, add_host_weight_fn: Callable = None): """ Finalize the host tensor sharing. 
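For reference, a minimal usage sketch of the `advise_tensor_pageout` helper added above. It assumes Linux and a host tensor that really is mmap-backed; the file path, size, and the use of `torch.from_file` here are illustrative and not part of the patch:

```python
import torch

from tensorrt_llm._torch.modules.fused_moe.moe_load_balancer import \
    advise_tensor_pageout

# Hypothetical mmap-backed host tensor: torch.from_file(shared=True) maps the
# file instead of copying it, so its pages can be dropped and repopulated from
# the backing file on the next access.
num_elems = 1 << 20  # illustrative size; the file must hold at least this many floats
weights = torch.from_file("/tmp/expert_weights.bin", shared=True,
                          size=num_elems, dtype=torch.float32)

_ = weights.sum()  # touch the data so pages become resident

# Hint the kernel that the host pages are no longer needed; the data can be
# re-read from the backing file if the tensor is accessed again.  Only do this
# for tensors whose contents can be restored from such a backing store.
advise_tensor_pageout(weights, mode="dontneed")
```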
@@ -402,6 +499,17 @@ def register_weight_slot(self, local_slot_id: int, name: str, moe_weight = _tensor_to_weight(t) self._add_weight_slot(local_slot_id, name, moe_weight) + def _add_raw_host_weight_for_unmap(self, + raw_weight_tensors: List[torch.Tensor]): + """ + Add a raw weight (mmapped Tensor) to LoadBalancer for later madvise to save host memory. + + Args: + raw_weight_tensors: A list of raw weight tensors + """ + self.host_tensor_sharer.add_raw_host_weight_for_unmap( + raw_weight_tensors) + def _add_host_weight(self, expert_id: int, name: str, host_weight: _tbr.MoeWeight): """ diff --git a/tensorrt_llm/_torch/modules/fused_moe/quantization.py b/tensorrt_llm/_torch/modules/fused_moe/quantization.py index 74cec95e865..f248432f37c 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/quantization.py +++ b/tensorrt_llm/_torch/modules/fused_moe/quantization.py @@ -251,6 +251,8 @@ def load_expert_weights_to_dst( self.load_expert_w2_weight(module, w2_weight, dst_w2_weights_tensor[expert_idx]) + module._add_raw_shared_weights_for_unmap( + [w1_weight, w3_weight, w2_weight]) if module.bias: self.load_expert_w3_w1_weight( @@ -259,6 +261,8 @@ def load_expert_weights_to_dst( self.load_expert_w2_weight(module, w2_bias, dst_w2_bias_tensor.data[expert_idx]) + module._add_raw_shared_weights_for_unmap( + [w1_bias, w3_bias, w2_bias]) def load_weights(self, module: torch.nn.Module, weights: List[Dict], weight_loading_mode: MoEWeightLoadingMode): @@ -1682,6 +1686,8 @@ def load_all_fp4_weight_scales_and_alphas( dst_w3_w1_weight_scale[expert_idx]) self.load_expert_w2_weight_scale_nvfp4( module, w2_weight_scale, dst_w2_weight_scale[expert_idx]) + module._add_raw_shared_weights_for_unmap( + [w1_weight_scale, w3_weight_scale, w2_weight_scale]) self.load_expert_fc31_alpha_nvfp4(w1_weight_scale_2, w3_weight_scale_2, From 5da386d9e2e43fc522a43414994d8cd90aabd638 Mon Sep 17 00:00:00 2001 From: Dongxu Yang <78518666+dongxuy04@users.noreply.github.com> Date: Wed, 5 Nov 2025 19:31:37 +0800 Subject: [PATCH 05/13] add EPLB support for cutlass backend and update test Signed-off-by: Dongxu Yang <78518666+dongxuy04@users.noreply.github.com> --- .../_torch/modules/fused_moe/create_moe.py | 4 +- .../modules/fused_moe/fused_moe_cute_dsl.py | 13 +- .../modules/fused_moe/fused_moe_cutlass.py | 156 +++++++++++------- .../defs/accuracy/test_llm_api_pytorch.py | 6 +- 4 files changed, 110 insertions(+), 69 deletions(-) diff --git a/tensorrt_llm/_torch/modules/fused_moe/create_moe.py b/tensorrt_llm/_torch/modules/fused_moe/create_moe.py index 2ff8a0121f0..6a5790bbfca 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/create_moe.py +++ b/tensorrt_llm/_torch/modules/fused_moe/create_moe.py @@ -80,8 +80,8 @@ def create_moe( moe_load_balancer = get_moe_load_balancer() if moe_load_balancer is not None: assert moe_cls in [ - WideEPMoE, TRTLLMGenFusedMoE - ], "MoE Load Balance is only supported in WideEPMoE and TRTLLMGenFusedMoE now." + WideEPMoE, CutlassFusedMoE, TRTLLMGenFusedMoE + ], "MoE Load Balance is only supported in WideEPMoE, CutlassFusedMoE and TRTLLMGenFusedMoE now." 
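As an aside on how these backends are exercised together with the load balancer, a condensed sketch of enabling online EPLB from the LLM API, mirroring the accuracy tests updated later in this series (model path, slot count, and import locations are assumptions for illustration):

```python
from tensorrt_llm import LLM
from tensorrt_llm.llmapi import (CudaGraphConfig, KvCacheConfig, MoeConfig,
                                 MoeLoadBalancerConfig)

# Redundant expert slots plus periodic slot updates enable online EPLB.
eplb_config = MoeLoadBalancerConfig(num_slots=80, layer_updates_per_iter=2)

llm = LLM(
    "/models/DeepSeek-V3-Lite",  # illustrative model path
    tensor_parallel_size=4,
    moe_expert_parallel_size=4,
    enable_attention_dp=True,
    kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.7),
    cuda_graph_config=CudaGraphConfig(),
    moe_config=MoeConfig(backend="WIDEEP",  # or "CUTLASS" / "TRTLLM"
                         load_balancer=eplb_config),
)
```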
if bias: assert moe_cls in [CutlassFusedMoE, TritonFusedMoE, TRTLLMGenFusedMoE diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl.py index 9e5b2a1e947..b90fc35e680 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl.py @@ -141,12 +141,13 @@ def __init__( ) def forward_chunk( - self, - x: Union[torch.Tensor, Fp4QuantizedTensor], - router_logits: torch.Tensor, - output_dtype: Optional[torch.dtype] = None, - all_rank_num_tokens: Optional[List[int]] = None, - use_dp_padding: Optional[bool] = None, + self, + x: Union[torch.Tensor, Fp4QuantizedTensor], + router_logits: torch.Tensor, + output_dtype: Optional[torch.dtype] = None, + all_rank_num_tokens: Optional[List[int]] = None, + use_dp_padding: Optional[bool] = None, + repeating_info: tuple = (True, True), ) -> torch.Tensor: if isinstance(x, Fp4QuantizedTensor): assert output_dtype is not None diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py index fbb487e3423..bfd45583c39 100755 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py @@ -9,6 +9,7 @@ from tensorrt_llm.logger import logger from ...distributed import allgather +from ...expert_statistic import ExpertStatistic from ...model_config import ModelConfig from ...utils import AuxStreamType, EventType, Fp4QuantizedTensor, ceil_div from .interface import AlltoallMethodType, MoE @@ -99,19 +100,9 @@ def __init__( self.intermediate_size_per_partition = ( (self.intermediate_size_per_partition + 127) // 128) * 128 - self.num_slots = self.num_experts - self.expert_size_per_partition = self.num_experts // self.ep_size - self.initial_global_assignments = [ - (ep_rank * self.num_experts // self.ep_size + local_slot_id) % - self.num_experts for ep_rank in range(self.ep_size) - for local_slot_id in range(self.expert_size_per_partition) - ] - self.slot_start = self.ep_rank * self.expert_size_per_partition - self.slot_end = self.slot_start + self.expert_size_per_partition - self.initial_local_expert_ids = self.initial_global_assignments[ - self.slot_start:self.slot_end] - assert len( - self.initial_local_expert_ids) == self.expert_size_per_partition + # Note: num_slots, expert_size_per_partition, initial_global_assignments, + # slot_start, slot_end, initial_local_expert_ids are all initialized by + # base class's _init_load_balancer() method # The maximum number of tokens in MoE are multiplied by DP size when attention DP is enabled moe_max_num_tokens = model_config.max_num_tokens * model_config.mapping.dp_size @@ -260,6 +251,10 @@ def moe_alltoall_backend(self): return os.environ.get("TRTLLM_MOE_ALLTOALL_BACKEND", "mnnvllatency").strip().lower() + def _supports_load_balancer(self) -> bool: + """CutlassFusedMoE supports load balancer.""" + return True + def _get_quant_method(self): if self.quant_config is not None and self.quant_config.layer_quant_mode.has_any_quant( exclude_kv_cache=True): @@ -298,12 +293,13 @@ def create_weights(self): self._check_configs() def forward_chunk( - self, - x: Union[torch.Tensor, Fp4QuantizedTensor], - router_logits: torch.Tensor, - output_dtype: Optional[torch.dtype] = None, - all_rank_num_tokens: Optional[List[int]] = None, - use_dp_padding: Optional[bool] = None, + self, + x: Union[torch.Tensor, Fp4QuantizedTensor], + router_logits: torch.Tensor, + 
output_dtype: Optional[torch.dtype] = None, + all_rank_num_tokens: Optional[List[int]] = None, + use_dp_padding: Optional[bool] = None, + repeating_info: tuple = (True, True), ) -> torch.Tensor: if isinstance(x, Fp4QuantizedTensor): assert output_dtype is not None @@ -311,6 +307,10 @@ def forward_chunk( else: output_dtype = x.dtype + is_first_call, is_last_call = repeating_info + + self._load_balancer_start_wait_gpu_stage(is_first_call) + # apply routing token_selected_experts, token_final_scales = self.routing_method.apply( router_logits) @@ -321,6 +321,22 @@ def forward_chunk( assert token_final_scales.dtype == torch.float32 assert token_selected_experts.dtype == torch.int32 + if self.layer_load_balancer: + self._load_balancer_done_wait_gpu_stage(is_first_call) + use_mnnvl = self.enable_alltoall and self.alltoall_method_type == AlltoallMethodType.MNNVL + self._load_balancer_update_statistic(token_selected_experts, + is_first_call, is_last_call, + use_mnnvl) + token_selected_slots = self._load_balancer_route( + token_selected_experts, self.use_dp) + else: + token_selected_slots = token_selected_experts + + # If load balancer is disabled, the statistics are collected from expert IDs. + # If load balancer is enabled, the statistics are collected from expert slot IDs. + ExpertStatistic.set_layer(self.layer_idx) + ExpertStatistic.maybe_add_info(self.num_slots, token_selected_slots) + if self.apply_router_weight_on_input: assert x.dtype != torch.float8_e4m3fn, "Current workaround for apply_router_weight_on_input does not support fp8 input" x = x * token_final_scales.to(x.dtype) @@ -395,7 +411,7 @@ def forward_chunk( tuner_num_tokens = sum(all_rank_num_tokens) else: tuner_num_tokens = x.shape[0] * self.mapping.tp_size - tuner_top_k = token_selected_experts.shape[1] + tuner_top_k = token_selected_slots.shape[1] else: tuner_num_tokens = None tuner_top_k = None @@ -413,32 +429,41 @@ def forward_chunk( # Handle case where token_final_scales might be None (when apply_router_weight_on_input=True) if token_final_scales is None: - token_final_scales = torch.ones_like(token_selected_experts, + token_final_scales = torch.ones_like(token_selected_slots, dtype=torch.float32) if self.moe_alltoall_backend == "mnnvllatency": assert self.alltoall_prepare_workspace is not None, "alltoall_prepare_workspace should be initialized" - alltoall_info, _ = MnnvlMoe.mnnvl_moe_alltoallv_prepare_without_allgather( - token_selected_experts, None, + if is_last_call: + loadbalancer_local_statistic_info = self._load_balancer_get_local_statistic_tensor( + ) + else: + loadbalancer_local_statistic_info = None + alltoall_info, gathered_loadbalancer_local_statistic_info = MnnvlMoe.mnnvl_moe_alltoallv_prepare_without_allgather( + token_selected_slots, loadbalancer_local_statistic_info, self.alltoall_prepare_workspace, runtime_max_tokens_per_rank, self.ep_rank, self.ep_size, - self.num_experts, self.num_experts, top_k) + self.num_experts, self.num_slots, top_k) + if gathered_loadbalancer_local_statistic_info is not None: + gathered_loadbalancer_local_statistic_info = gathered_loadbalancer_local_statistic_info.view( + (self.mapping.moe_ep_size, self.num_experts)) + self._load_balancer_update_statistic_with_gathered_statistic( + gathered_loadbalancer_local_statistic_info) if x_sf is not None: x_sf = x_sf.view(x_row, ceil_div(x_col, self.scaling_vector_size)) is_sf_swizzled = False - # Dispatch x, x_sf, token_selected_experts, token_final_scales in one alltoall kernel - x, x_sf, token_selected_experts, token_final_scales = 
MnnvlMoe.mnnvl_moe_alltoallv( - [x, x_sf, token_selected_experts, token_final_scales], + # Dispatch x, x_sf, token_selected_slots, token_final_scales in one alltoall kernel + x, x_sf, token_selected_slots, token_final_scales = MnnvlMoe.mnnvl_moe_alltoallv( + [x, x_sf, token_selected_slots, token_final_scales], alltoall_info, self.alltoall_workspace, self.ep_rank, self.ep_size) torch.ops.trtllm.memset_expert_ids( - token_selected_experts, - alltoall_info.recv_rank_count_cumsum, - runtime_max_tokens_per_rank, top_k, self.num_experts, + token_selected_slots, alltoall_info.recv_rank_count_cumsum, + runtime_max_tokens_per_rank, top_k, self.num_slots, self.ep_size) elif self.moe_alltoall_backend == "mnnvlthroughput": # Python MoeAlltoAll path @@ -454,11 +479,11 @@ def forward_chunk( expert_id_payload_index = 2 else: expert_id_payload_index = 1 - payloads.append(token_selected_experts) + payloads.append(token_selected_slots) payloads.append(token_final_scales) recv_tensors = self.moe_a2a.dispatch( - token_selected_experts, + token_selected_slots, payloads, runtime_max_tokens_per_rank, invalid_token_expert_id=self. @@ -467,13 +492,13 @@ def forward_chunk( ) if x_sf is not None: - x_recv, x_sf_recv, token_selected_experts_recv, token_final_scales_recv = recv_tensors + x_recv, x_sf_recv, token_selected_slots_recv, token_final_scales_recv = recv_tensors x_sf = x_sf_recv.view(-1, x_sf_recv.shape[-1]) else: - x_recv, token_selected_experts_recv, token_final_scales_recv = recv_tensors + x_recv, token_selected_slots_recv, token_final_scales_recv = recv_tensors x = x_recv.view(-1, x_recv.shape[-1]) - token_selected_experts = token_selected_experts_recv.view( - -1, token_selected_experts_recv.shape[-1]) + token_selected_slots = token_selected_slots_recv.view( + -1, token_selected_slots_recv.shape[-1]) token_final_scales = token_final_scales_recv.view( -1, token_final_scales_recv.shape[-1]) else: @@ -491,8 +516,8 @@ def forward_chunk( ) == 2, "The hidden states scaling factor should be 2D tensor before allgather" is_sf_swizzled = False - x, x_sf, token_selected_experts, token_final_scales = allgather( - [x, x_sf, token_selected_experts, token_final_scales], + x, x_sf, token_selected_slots, token_final_scales = allgather( + [x, x_sf, token_selected_slots, token_final_scales], self.mapping, dim=0, sizes=None if use_dp_padding else all_rank_num_tokens) @@ -509,7 +534,7 @@ def forward_chunk( output_dtype) final_hidden_states = torch.ops.trtllm.fused_moe( x, - token_selected_experts, + token_selected_slots, token_final_scales, self.w3_w1_weight.view(weight_dtype), self.w3_w1_bias, @@ -546,6 +571,8 @@ def forward_chunk( # Otherwise, the output should be unpacked as a single tensor. 
final_hidden_states = final_hidden_states[0] + self._load_balancer_start_set_cpu_stage(is_last_call) + # Combine results if using alltoall if self.enable_alltoall: if self.moe_alltoall_backend == "mnnvllatency": @@ -576,6 +603,8 @@ def forward_chunk( f"Unsupported moe alltoall backend: {self.moe_alltoall_backend}" ) + self._load_balancer_done_set_cpu_stage(is_last_call) + return final_hidden_states def split_chunk(self, split_token_num: int, split_num_chunks: int): @@ -616,12 +645,15 @@ def forward_impl( 1) // self.moe_max_num_tokens if num_chunks == 1: + is_first_call = self.repeat_idx == 0 + is_last_call = self.repeat_idx == self.repeat_count - 1 outputs = self.forward_chunk( x, router_logits, output_dtype, all_rank_num_tokens=all_rank_num_tokens_padded, - use_dp_padding=use_dp_padding) + use_dp_padding=use_dp_padding, + repeating_info=(is_first_call, is_last_call)) outputs = self.reducescatter_or_allreduce( outputs, all_rank_num_tokens=all_rank_num_tokens_padded, @@ -648,12 +680,15 @@ def forward_impl( self.event_dict[EventType.Main].wait() def _forward_chunk(x_, router_logits_, idx): + is_first_call = idx == 0 and self.repeat_idx == 0 + is_last_call = idx == num_chunks - 1 and self.repeat_idx == self.repeat_count - 1 return self.forward_chunk( x_, router_logits_, all_rank_num_tokens=all_rank_num_tokens_list[idx] if self.use_dp else None, - use_dp_padding=use_dp_padding) + use_dp_padding=use_dp_padding, + repeating_info=(is_first_call, is_last_call)) def _reducescatter_or_allreduce(x_, idx): return self.reducescatter_or_allreduce( @@ -665,37 +700,42 @@ def _reducescatter_or_allreduce(x_, idx): # Postpone reduce-scatter/all-reduce to the next iteration to achieve better overlap for idx_chunk, (x, router_logits) in enumerate( zip(x_list, router_logits_list)): - - if idx_chunk % 2 == 0: - with torch.cuda.stream(self.aux_stream): + if not (self.moe_alltoall_backend == "mnnvllatency"): + if idx_chunk % 2 == 0: + with torch.cuda.stream(self.aux_stream): + outputs = _forward_chunk(x, router_logits, + idx_chunk) + if idx_chunk > 0: + outputs_list[-1] = _reducescatter_or_allreduce( + outputs_list[-1], idx_chunk - 1) + else: outputs = _forward_chunk(x, router_logits, idx_chunk) - if idx_chunk > 0: - outputs_list[-1] = _reducescatter_or_allreduce( - outputs_list[-1], idx_chunk - 1) + with torch.cuda.stream(self.aux_stream): + outputs_list[-1] = _reducescatter_or_allreduce( + outputs_list[-1], idx_chunk - 1) else: outputs = _forward_chunk(x, router_logits, idx_chunk) - with torch.cuda.stream(self.aux_stream): - outputs_list[-1] = _reducescatter_or_allreduce( - outputs_list[-1], idx_chunk - 1) outputs_list.append(outputs) - if num_chunks % 2 == 0: - outputs_list[-1] = _reducescatter_or_allreduce( - outputs_list[-1], -1) - else: - with torch.cuda.stream(self.aux_stream): + if not (self.moe_alltoall_backend == "mnnvllatency"): + if num_chunks % 2 == 0: outputs_list[-1] = _reducescatter_or_allreduce( outputs_list[-1], -1) - with torch.cuda.stream(self.aux_stream): - self.event_dict[EventType.MoeChunkingOverlap].record() - self.event_dict[EventType.MoeChunkingOverlap].wait() + else: + with torch.cuda.stream(self.aux_stream): + outputs_list[-1] = _reducescatter_or_allreduce( + outputs_list[-1], -1) + with torch.cuda.stream(self.aux_stream): + self.event_dict[EventType.MoeChunkingOverlap].record() + self.event_dict[EventType.MoeChunkingOverlap].wait() outputs = torch.cat(outputs_list) if self.use_dp and self.parallel_size > 1: rank = self.mapping.tp_rank outputs = outputs[:all_rank_num_tokens[rank]] + 
self.repeat_idx = 0 if self.repeat_idx == self.repeat_count - 1 else self.repeat_idx + 1 return outputs def forward_fake( diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py index c07e6e99527..e054ca1d096 100644 --- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py +++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py @@ -1617,7 +1617,7 @@ def test_fp8_block_scales_4gpus_static_eplb(self): @pytest.mark.skip_less_device(4) @pytest.mark.skip_device_not_contain(["GB200"]) - @parametrize_with_ids("moe_backend", ["WIDEEP"]) + @parametrize_with_ids("moe_backend", ["WIDEEP", "CUTLASS", "TRTLLM"]) @parametrize_with_ids("mtp_nextn", [0, 2]) def test_bfloat16_4gpus_online_eplb(self, moe_backend, mtp_nextn): kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.7) @@ -3917,7 +3917,7 @@ def test_eagle3(self, moe_backend, mocker): chat_template_kwargs = dict(reasoning_effort="medium") extra_evaluator_kwargs = { **self.extra_evaluator_kwargs, "chat_template_kwargs": - chat_template_kwargs + chat_template_kwargs } sampling_params = SamplingParams( @@ -3930,7 +3930,6 @@ def test_eagle3(self, moe_backend, mocker): sampling_params=sampling_params, extra_evaluator_kwargs=extra_evaluator_kwargs) - @pytest.mark.skip_less_device(4) @pytest.mark.skip_device_not_contain(["GB200"]) @pytest.mark.parametrize( @@ -3967,6 +3966,7 @@ def test_w4_4gpus_online_eplb(self, kv_cache_dtype, mocker): task.evaluate(llm, extra_evaluator_kwargs=self.extra_evaluator_kwargs) + @skip_pre_hopper class TestEXAONE4(LlmapiAccuracyTestHarness): MODEL_NAME = "LGAI-EXAONE/EXAONE-4.0-32B" From 29779a37b646bcf8625ccd348192369ce7b3ad41 Mon Sep 17 00:00:00 2001 From: Dongxu Yang <78518666+dongxuy04@users.noreply.github.com> Date: Wed, 5 Nov 2025 19:31:54 +0800 Subject: [PATCH 06/13] explicitly mark EPLB unsupport and unverified quant methods Signed-off-by: Dongxu Yang <78518666+dongxuy04@users.noreply.github.com> --- .../_torch/modules/fused_moe/quantization.py | 40 +++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/tensorrt_llm/_torch/modules/fused_moe/quantization.py b/tensorrt_llm/_torch/modules/fused_moe/quantization.py index f248432f37c..eea4a3d43b0 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/quantization.py +++ b/tensorrt_llm/_torch/modules/fused_moe/quantization.py @@ -156,6 +156,23 @@ class FusedMoEMethodBase(ABC): """ weight_alignment: int = 1 + @classmethod + def _online_eplb_not_supported(cls, module): + if hasattr( + module, "layer_load_balancer" + ) and module.layer_load_balancer and module.layer_load_balancer.need_load_shared_weights( + ): + raise NotImplementedError( + f'{cls.__name__} doesn\'t support online EPLB now') + + @classmethod + def _online_eplb_not_verified(cls, module): + if hasattr( + module, "layer_load_balancer" + ) and module.layer_load_balancer and module.layer_load_balancer.need_load_shared_weights( + ): + logger.warning(f'{cls.__name__} online EPLB is not verified yet') + def need_load_shared_weights(self, module): if hasattr( module, "layer_load_balancer" @@ -516,6 +533,8 @@ def create_weights(self, module: torch.nn.Module): requires_grad=False) module.register_parameter("fc31_input_dequant", fc31_input_dequant) + self._online_eplb_not_supported(module) + self.setup_quant_scales(module) def setup_quant_scales(self, module: torch.nn.Module): @@ -618,6 +637,8 @@ def create_weights(self, module: torch.nn.Module): module.register_parameter("w2_weight_scaling_factor", w2_weight_scaling_factor) + 
self._online_eplb_not_verified(module) + self.setup_quant_scales(module) def load_weights(self, module: torch.nn.Module, weights: List[Dict], @@ -794,6 +815,9 @@ def create_weights(self, module: torch.nn.Module): super().create_weights(module, weight_dtype, w3_w1_weight_shape, w2_weight_shape) + + self._online_eplb_not_supported(module) + self.setup_quant_scales(module) def setup_quant_scales(self, module: torch.nn.Module): @@ -967,6 +991,9 @@ def create_weights(self, module: torch.nn.Module): super().create_weights(module, weight_dtype, w3_w1_weight_shape, w2_weight_shape) + + self._online_eplb_not_supported(module) + self.setup_quant_scales(module) def setup_quant_scales(self, module: torch.nn.Module): @@ -1364,6 +1391,9 @@ def create_weights(self, module: torch.nn.Module): super().create_weights(module, weight_dtype, w3_w1_weight_shape, w2_weight_shape) + + self._online_eplb_not_supported(module) + self.setup_quant_scales(module) def setup_quant_scales(self, module: torch.nn.Module): @@ -2145,6 +2175,8 @@ def create_weights(self, module: torch.nn.Module): requires_grad=False) module.register_parameter("fc31_scale_c", fc31_scale_c) + self._online_eplb_not_verified(module) + self.setup_quant_scales(module) def load_expert_w3_w1_weight_scale_nvfp4( @@ -2340,6 +2372,8 @@ def create_weights(self, module: torch.nn.Module): self.block_scales_dtype, block_scales_vec_size, self.weight_alignment) + self._online_eplb_not_verified(module) + def load_expert_w3_w1_weight(self, module: torch.nn.Module, w1_weight: torch.Tensor, w3_weight: torch.Tensor, @@ -2506,6 +2540,8 @@ def create_weights(self, module: torch.nn.Module): super().create_weights(module) + self._online_eplb_not_verified(module) + self.setup_quant_scales(module) def load_quant_scales(self, module: torch.nn.Module, weights: Dict): @@ -2547,6 +2583,8 @@ def create_weights(self, module: torch.nn.Module): super().create_weights(module) + self._online_eplb_not_supported(module) + self.setup_quant_scales(module) def load_expert_fc31_input_scale_w4a8_mxfp4_fp8( @@ -2902,6 +2940,8 @@ def create_weights(self, module: torch.nn.Module): super().create_weights(module) + self._online_eplb_not_supported(module) + def load_expert_fc31_input_scale_w4a8_mxfp4_fp8( self, w1_input_scale, w3_input_scale, w2_input_scale, dst_fc31_input_scale: torch.Tensor, From d330870a8adea2c538f89e47395676114a97f809 Mon Sep 17 00:00:00 2001 From: Dongxu Yang <78518666+dongxuy04@users.noreply.github.com> Date: Wed, 5 Nov 2025 19:32:10 +0800 Subject: [PATCH 07/13] update test lists Signed-off-by: Dongxu Yang <78518666+dongxuy04@users.noreply.github.com> --- tests/integration/defs/.test_durations | 8 ++++---- tests/integration/test_lists/qa/llm_function_core.txt | 8 ++++---- .../test_lists/qa/llm_function_core_sanity.txt | 8 ++++---- .../test_lists/test-db/l0_gb200_multi_gpus.yml | 7 +++++-- .../test_lists/test-db/l0_gb300_multi_gpus.yml | 6 ++++-- tests/integration/test_lists/test-db/l0_rtx_pro_6000.yml | 6 ++++-- 6 files changed, 25 insertions(+), 18 deletions(-) diff --git a/tests/integration/defs/.test_durations b/tests/integration/defs/.test_durations index 62195d1f26b..f5ad45b5fb0 100644 --- a/tests/integration/defs/.test_durations +++ b/tests/integration/defs/.test_durations @@ -225,8 +225,8 @@ "accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=2-attention_dp=True-cuda_graph=False-overlap_scheduler=False-torch_compile=False-enable_chunked_prefill=False]": 113.82226522010751, 
"accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=2-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False-enable_chunked_prefill=False]": 205.7252635700861, "accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=2-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False-enable_chunked_prefill=True]": 213.78996226208983, - "accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus_online_eplb[mtp_nextn=0]": 292.7267158059985377, - "accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus_online_eplb[mtp_nextn=2]": 292.5756296711042523, + "accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus_online_eplb[mtp_nextn=0-moe_backend=WIDEEP]": 292.7267158059985377, + "accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus_online_eplb[mtp_nextn=2-moe_backend=WIDEEP]": 292.5756296711042523, "accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[-attention_dp-cuda_graph-overlap_scheduler-torch_compile=False]": 326.1317654890008, "accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp=disable-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]": 647.6109309499152, "accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp=disable-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=True-torch_compile=False]": 184.20976317999884, @@ -252,8 +252,8 @@ "accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4[moe_backend=CUTLASS-mtp_nextn=0-fp8kv=False-attention_dp=True-cuda_graph=False-overlap_scheduler=False-torch_compile=False]": 90.21807348495349, "accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4[moe_backend=CUTLASS-mtp_nextn=0-fp8kv=False-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]": 175.661773331929, "accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4[moe_backend=CUTLASS-mtp_nextn=2-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]": 360.0003233290044590831, - "accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus_online_eplb[fp8kv=False]": 360.9275938109494746, - "accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus_online_eplb[fp8kv=True]": 360.0002855450729839504, + "accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus_online_eplb[fp8kv=False-moe_backend=WIDEEP]": 360.9275938109494746, + "accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus_online_eplb[fp8kv=True-moe_backend=WIDEEP]": 360.0002855450729839504, "accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_batch_waiting[batch_wait_timeout_iters=10-batch_wait_max_tokens_ratio=0.75-mtp_nextn=0-fp8kv=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]": 360.0003064870252273977, "accuracy/test_llm_api_pytorch.py::TestEXAONE4::test_auto_dtype": 3600.0004039629711769521, "accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[True-True-cutlass-auto]": 360.00032637204276397824, diff --git a/tests/integration/test_lists/qa/llm_function_core.txt b/tests/integration/test_lists/qa/llm_function_core.txt index d7a8f0d82f6..0bd3446cb75 100644 --- a/tests/integration/test_lists/qa/llm_function_core.txt +++ b/tests/integration/test_lists/qa/llm_function_core.txt @@ -464,10 +464,10 @@ 
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4[moe_backend=CUT accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4[moe_backend=CUTLASS-mtp_nextn=2-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_batch_waiting[batch_wait_timeout_iters=10-batch_wait_max_tokens_ratio=0.75-mtp_nextn=0-fp8kv=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus_static_eplb -accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus_online_eplb[mtp_nextn=0] -accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus_online_eplb[mtp_nextn=2] -accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus_online_eplb[fp8kv=False] -accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus_online_eplb[fp8kv=True] +accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus_online_eplb[mtp_nextn=0-moe_backend=WIDEEP] +accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus_online_eplb[mtp_nextn=2-moe_backend=WIDEEP] +accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus_online_eplb[fp8kv=False-moe_backend=WIDEEP] +accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus_online_eplb[fp8kv=True-moe_backend=WIDEEP] accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_guided_decoding[xgrammar-mtp_nextn=0] accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_guided_decoding[xgrammar-mtp_nextn=2] accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_guided_decoding[llguidance-mtp_nextn=0] diff --git a/tests/integration/test_lists/qa/llm_function_core_sanity.txt b/tests/integration/test_lists/qa/llm_function_core_sanity.txt index 7f6bdfad85e..24bd5d17bd2 100644 --- a/tests/integration/test_lists/qa/llm_function_core_sanity.txt +++ b/tests/integration/test_lists/qa/llm_function_core_sanity.txt @@ -56,13 +56,13 @@ accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus[baselin accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus[baseline_mtp1] accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus[baseline_fp8kv] accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus[latency] -accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus_online_eplb[mtp_nextn=0] -accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus_online_eplb[mtp_nextn=2] +accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus_online_eplb[mtp_nextn=0-moe_backend=WIDEEP] +accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus_online_eplb[mtp_nextn=2-moe_backend=WIDEEP] accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=0-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False-enable_chunked_prefill=False] accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus_static_eplb accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp=disable-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] -accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus_online_eplb[fp8kv=False] -accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus_online_eplb[fp8kv=True] 
+accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus_online_eplb[fp8kv=False-moe_backend=WIDEEP] +accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus_online_eplb[fp8kv=True-moe_backend=WIDEEP] accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4[moe_backend=CUTLASS-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4[moe_backend=CUTLASS-mtp_nextn=2-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_auto_dtype diff --git a/tests/integration/test_lists/test-db/l0_gb200_multi_gpus.yml b/tests/integration/test_lists/test-db/l0_gb200_multi_gpus.yml index eca2f986a8a..3898b101ea3 100644 --- a/tests/integration/test_lists/test-db/l0_gb200_multi_gpus.yml +++ b/tests/integration/test_lists/test-db/l0_gb200_multi_gpus.yml @@ -34,12 +34,15 @@ l0_gb200_multi_gpus: - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=0-pp4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=True] - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=TRTLLM-mtp_nextn=0-tp4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=TRTLLM-mtp_nextn=0-ep4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] - - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus_online_eplb[fp8kv=True] - - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus_online_eplb[mtp_nextn=2] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus_online_eplb[fp8kv=True-moe_backend=WIDEEP] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus_online_eplb[fp8kv=True-moe_backend=TRTLLM] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus_online_eplb[mtp_nextn=2-moe_backend=WIDEEP] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus_online_eplb[mtp_nextn=2-moe_backend=CUTLASS] - accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp8_tp4 - accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_nvfp4_tp4 - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[tp4-trtllm-auto] - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[dp4-cutlass-auto] + - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus_online_eplb[fp8] - condition: ranges: system_gpu_count: diff --git a/tests/integration/test_lists/test-db/l0_gb300_multi_gpus.yml b/tests/integration/test_lists/test-db/l0_gb300_multi_gpus.yml index 9e9d2a675cb..6fe1933d422 100644 --- a/tests/integration/test_lists/test-db/l0_gb300_multi_gpus.yml +++ b/tests/integration/test_lists/test-db/l0_gb300_multi_gpus.yml @@ -39,8 +39,10 @@ l0_gb300_multi_gpus: - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=2-pp4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=TRTLLM-mtp_nextn=2-tp4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] - 
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=TRTLLM-mtp_nextn=2-ep4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] ISOLATION - - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus_online_eplb[fp8kv=True] - - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus_online_eplb[mtp_nextn=2] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus_online_eplb[fp8kv=True-moe_backend=WIDEEP] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus_online_eplb[fp8kv=True-moe_backend=TRTLLM] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus_online_eplb[mtp_nextn=2-moe_backend=WIDEEP] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus_online_eplb[mtp_nextn=2-moe_backend=CUTLASS] - accuracy/test_llm_api_pytorch.py::TestQwen3_235B_A22B::test_nvfp4_4gpus[latency_moe_trtllm_eagle3] TIMEOUT (90) - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16_4gpus[tp4-attn_backend=TRTLLM-torch_compile=True] - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16_4gpus[tp4-attn_backend=FLASHINFER-torch_compile=False] diff --git a/tests/integration/test_lists/test-db/l0_rtx_pro_6000.yml b/tests/integration/test_lists/test-db/l0_rtx_pro_6000.yml index f8dd8b6ea4d..786b03e0e38 100644 --- a/tests/integration/test_lists/test-db/l0_rtx_pro_6000.yml +++ b/tests/integration/test_lists/test-db/l0_rtx_pro_6000.yml @@ -93,8 +93,10 @@ l0_rtx_pro_6000: - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=2-tp2pp2-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=2-pp4-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=2-pp4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] - # - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus_online_eplb[fp8kv=True] # Verify GDRCopy availability on Blossom pods - # - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus_online_eplb[mtp_nextn=2] # Verify GDRCopy availability on Blossom pods + # - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus_online_eplb[fp8kv=True-moe_backend=WIDEEP] # Verify GDRCopy availability on Blossom pods + # - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus_online_eplb[fp8kv=True-moe_backend=TRTLLM] # Verify GDRCopy availability on Blossom pods + # - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus_online_eplb[mtp_nextn=2-moe_backend=WIDEEP] # Verify GDRCopy availability on Blossom pods + # - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus_online_eplb[mtp_nextn=2-moe_backend=CUTLASS] # Verify GDRCopy availability on Blossom pods # - accuracy/test_disaggregated_serving.py::TestQwen3_8B::test_auto_dtype[False] # hopper only # - accuracy/test_disaggregated_serving.py::TestQwen3_8B::test_auto_dtype[True] - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding[xgrammar-mtp_nextn=0] From 9f7b8a55161af2c46dbaa09451d541c0e2820156 Mon Sep 17 00:00:00 2001 From: Dongxu Yang 
<78518666+dongxuy04@users.noreply.github.com> Date: Thu, 6 Nov 2025 12:48:02 +0800 Subject: [PATCH 08/13] enable EPLB with MNNVL alltoall for trtllm-gen backend Signed-off-by: Dongxu Yang <78518666+dongxuy04@users.noreply.github.com> --- .../modules/fused_moe/fused_moe_cutlass.py | 10 ++++--- .../modules/fused_moe/fused_moe_trtllm_gen.py | 29 +++++++++++++------ .../modules/fused_moe/fused_moe_wide_ep.py | 4 +-- .../_torch/modules/fused_moe/interface.py | 14 +++++++-- 4 files changed, 39 insertions(+), 18 deletions(-) diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py index bfd45583c39..c9214287788 100755 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py @@ -323,10 +323,12 @@ def forward_chunk( if self.layer_load_balancer: self._load_balancer_done_wait_gpu_stage(is_first_call) - use_mnnvl = self.enable_alltoall and self.alltoall_method_type == AlltoallMethodType.MNNVL - self._load_balancer_update_statistic(token_selected_experts, - is_first_call, is_last_call, - use_mnnvl) + ignore_allreduce = self.enable_alltoall and self.alltoall_method_type == AlltoallMethodType.MNNVL and self.moe_alltoall_backend == "mnnvllatency" + self._load_balancer_update_statistic( + token_selected_experts, + is_first_call, + is_last_call, + ignore_allreduce=ignore_allreduce) token_selected_slots = self._load_balancer_route( token_selected_experts, self.use_dp) else: diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py index e8fcdd959d6..4c3c199ff2d 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py @@ -341,6 +341,8 @@ def forward_impl( x_col = x.shape[1] token_count = x.shape[0] alltoall_info = None + is_first_call = self.repeat_idx == 0 + is_last_call = self.repeat_idx == self.repeat_count - 1 if post_quant_comm: token_selected_experts, token_final_scales = self.routing_method.apply( @@ -352,18 +354,17 @@ def forward_impl( # Apply load balancer routing if available if self.layer_load_balancer: # Determine if this is first/last call (TRTLLMGenFusedMoE doesn't use chunking) - is_first_call = self.repeat_idx == 0 - is_last_call = self.repeat_idx == self.repeat_count - 1 # Start GPU stage for first call self._load_balancer_start_wait_gpu_stage(is_first_call) self._load_balancer_done_wait_gpu_stage(is_first_call) - # Update load balancer statistics (TRTLLMGenFusedMoE doesn't use MNNVL) - self._load_balancer_update_statistic(token_selected_experts, - is_first_call, - is_last_call, - use_mnnvl=False) + ignore_allreduce = self.enable_alltoall and self.alltoall_method_type == AlltoallMethodType.MNNVL and self.moe_alltoall_backend == "mnnvllatency" + self._load_balancer_update_statistic( + token_selected_experts, + is_first_call, + is_last_call, + ignore_allreduce=ignore_allreduce) # Route tokens to slots token_selected_slots = self._load_balancer_route( @@ -393,9 +394,14 @@ def forward_impl( if self.moe_alltoall_backend == "mnnvllatency": assert self.alltoall_prepare_workspace is not None, "alltoall_prepare_workspace should be initialized" - alltoall_info, _ = MnnvlMoe.mnnvl_moe_alltoallv_prepare_without_allgather( + if is_last_call: + loadbalancer_local_statistic_info = self._load_balancer_get_local_statistic_tensor( + ) + else: + loadbalancer_local_statistic_info = None + 
alltoall_info, gathered_loadbalancer_local_statistic_info = MnnvlMoe.mnnvl_moe_alltoallv_prepare_without_allgather( token_selected_experts, - None, + loadbalancer_local_statistic_info, self.alltoall_prepare_workspace, runtime_max_tokens_per_rank, self.ep_rank, @@ -404,6 +410,11 @@ def forward_impl( self.num_slots, top_k, ) + if gathered_loadbalancer_local_statistic_info is not None: + gathered_loadbalancer_local_statistic_info = gathered_loadbalancer_local_statistic_info.view( + (self.mapping.moe_ep_size, self.num_experts)) + self._load_balancer_update_statistic_with_gathered_statistic( + gathered_loadbalancer_local_statistic_info) if x_sf is not None: x_sf = x_sf.view(x_row, diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py index 0496e4172ca..f77dbc01113 100755 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py @@ -430,10 +430,10 @@ def forward_chunk( if self.layer_load_balancer: self._load_balancer_done_wait_gpu_stage(is_first_call) - use_mnnvl = use_all_to_all and self.alltoall_method_type == AlltoallMethodType.MNNVL + ignore_allreduce = self.enable_alltoall and self.alltoall_method_type == AlltoallMethodType.MNNVL and self.moe_alltoall_backend == "mnnvllatency" self._load_balancer_update_statistic(token_selected_experts, is_first_call, is_last_call, - use_mnnvl) + ignore_allreduce) token_selected_slots = self._load_balancer_route( token_selected_experts, self.use_dp) else: diff --git a/tensorrt_llm/_torch/modules/fused_moe/interface.py b/tensorrt_llm/_torch/modules/fused_moe/interface.py index 0974738cf76..5cd45984183 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/interface.py +++ b/tensorrt_llm/_torch/modules/fused_moe/interface.py @@ -328,10 +328,18 @@ def _load_balancer_update_statistic(self, token_selected_experts: torch.Tensor, is_first_call: bool, is_last_call: bool, - use_mnnvl: bool = False): - """Update load balancer statistics.""" + ignore_allreduce: bool = False): + """ + Update load balancer statistics. + + Args: + token_selected_experts: The selected experts of all tokens, has shape of [tokenCount * topK] + is_first_call: Whether this is the first call for the same weights + is_last_call: Whether this is the last call for the same weights + ignore_allreduce: Whether to ignore allreduce, if True, only update local statistics, need call _load_balancer_get_local_statistic_tensor to get the local statistic tensor and then do external allgather and then call _load_balancer_update_statistic_with_gathered_statistic to update the global statistics. MnnvlLatency supports this. 
+ """ if self.layer_load_balancer: - if use_mnnvl: + if ignore_allreduce: self.layer_load_balancer.update_local_statistic( token_selected_experts, is_first_stage=is_first_call, From 95bfad204ac2026edc33110906982a2ffdb54211 Mon Sep 17 00:00:00 2001 From: Dongxu Yang <78518666+dongxuy04@users.noreply.github.com> Date: Thu, 6 Nov 2025 15:55:52 +0800 Subject: [PATCH 09/13] address rebase test list issue Signed-off-by: Dongxu Yang <78518666+dongxuy04@users.noreply.github.com> --- tests/integration/test_lists/qa/llm_function_nim.txt | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/integration/test_lists/qa/llm_function_nim.txt b/tests/integration/test_lists/qa/llm_function_nim.txt index 92d03850317..4bbdf92f6ad 100644 --- a/tests/integration/test_lists/qa/llm_function_nim.txt +++ b/tests/integration/test_lists/qa/llm_function_nim.txt @@ -277,10 +277,10 @@ accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp= accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4[moe_backend=CUTLASS-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4[moe_backend=CUTLASS-mtp_nextn=2-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus_static_eplb -accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus_online_eplb[mtp_nextn=0] -accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus_online_eplb[mtp_nextn=2] -accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus_online_eplb[fp8kv=False] -accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus_online_eplb[fp8kv=True] +accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus_online_eplb[mtp_nextn=0-moe_backend=WIDEEP] +accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus_online_eplb[mtp_nextn=2-moe_backend=WIDEEP] +accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus_online_eplb[fp8kv=False-moe_backend=WIDEEP] +accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus_online_eplb[fp8kv=True-moe_backend=WIDEEP] accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_guided_decoding[xgrammar-mtp_nextn=0] accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_guided_decoding[xgrammar-mtp_nextn=2] accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_guided_decoding[llguidance-mtp_nextn=0] From e5726a0ac8fd507d2f66b23d99176c613196b356 Mon Sep 17 00:00:00 2001 From: Dongxu Yang <78518666+dongxuy04@users.noreply.github.com> Date: Fri, 7 Nov 2025 10:41:45 +0800 Subject: [PATCH 10/13] refactor load balancer code Signed-off-by: Dongxu Yang <78518666+dongxuy04@users.noreply.github.com> --- .../modules/fused_moe/fused_moe_trtllm_gen.py | 49 +++++++++---------- .../_torch/modules/fused_moe/interface.py | 36 ++++++++------ .../modules/fused_moe/moe_load_balancer.py | 13 +++-- .../_torch/modules/fused_moe/quantization.py | 22 ++++----- 4 files changed, 62 insertions(+), 58 deletions(-) diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py index 4c3c199ff2d..d41c2a7a74a 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py @@ -341,42 +341,38 @@ def forward_impl( 
x_col = x.shape[1] token_count = x.shape[0] alltoall_info = None + # Determine if this is first/last call (TRTLLMGenFusedMoE doesn't use chunking) is_first_call = self.repeat_idx == 0 is_last_call = self.repeat_idx == self.repeat_count - 1 if post_quant_comm: + # Start GPU stage for first call + self._load_balancer_start_wait_gpu_stage(is_first_call) token_selected_experts, token_final_scales = self.routing_method.apply( router_logits) token_selected_experts = token_selected_experts.to(torch.int32) if token_final_scales is not None: token_final_scales = token_final_scales.to(torch.bfloat16) - # Apply load balancer routing if available - if self.layer_load_balancer: - # Determine if this is first/last call (TRTLLMGenFusedMoE doesn't use chunking) + self._load_balancer_done_wait_gpu_stage(is_first_call) - # Start GPU stage for first call - self._load_balancer_start_wait_gpu_stage(is_first_call) - self._load_balancer_done_wait_gpu_stage(is_first_call) - - ignore_allreduce = self.enable_alltoall and self.alltoall_method_type == AlltoallMethodType.MNNVL and self.moe_alltoall_backend == "mnnvllatency" - self._load_balancer_update_statistic( - token_selected_experts, - is_first_call, - is_last_call, - ignore_allreduce=ignore_allreduce) + ignore_allreduce = self.enable_alltoall and self.alltoall_method_type == AlltoallMethodType.MNNVL and self.moe_alltoall_backend == "mnnvllatency" + self._load_balancer_update_statistic( + token_selected_experts, + is_first_call, + is_last_call, + ignore_allreduce=ignore_allreduce) - # Route tokens to slots - token_selected_slots = self._load_balancer_route( - token_selected_experts, self.use_dp) + # Route tokens to slots + token_selected_slots = self._load_balancer_route( + token_selected_experts, self.use_dp) - # Update expert statistics - ExpertStatistic.set_layer(self.layer_idx) - ExpertStatistic.maybe_add_info(self.num_slots, - token_selected_slots) + # Update expert statistics + ExpertStatistic.set_layer(self.layer_idx) + ExpertStatistic.maybe_add_info(self.num_slots, token_selected_slots) - # Use routed slots for subsequent processing - token_selected_experts = token_selected_slots + # Use routed slots for subsequent processing + token_selected_experts = token_selected_slots x, x_sf, x_row, x_col = self._quantize_for_post_quant_comm(x) @@ -756,6 +752,9 @@ def forward_impl( "TRTLLMGenFusedMoE only supports fp8_block_scaling, nvfp4, w4a16_mxfp4, w4a8_mxfp4_mxfp8 and w4a8_mxfp4_fp8 dtypes." 
) + # Handle load balancer CPU stage if needed + self._load_balancer_start_set_cpu_stage(is_last_call) + # Combine results if using alltoall if self.enable_alltoall: if self.moe_alltoall_backend == "mnnvllatency": @@ -803,11 +802,7 @@ def forward_impl( use_dp_padding=use_dp_padding, ) - # Handle load balancer CPU stage if needed - if self.layer_load_balancer: - is_last_call = self.repeat_idx == self.repeat_count - 1 - self._load_balancer_start_set_cpu_stage(is_last_call) - self._load_balancer_done_set_cpu_stage(is_last_call) + self._load_balancer_done_set_cpu_stage(is_last_call) if use_dp_padding: rank = self.mapping.tp_rank diff --git a/tensorrt_llm/_torch/modules/fused_moe/interface.py b/tensorrt_llm/_torch/modules/fused_moe/interface.py index 5cd45984183..abaa20a8bb3 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/interface.py +++ b/tensorrt_llm/_torch/modules/fused_moe/interface.py @@ -227,7 +227,9 @@ def _init_load_balancer( ] # Setup load balancer if available - if moe_load_balancer and self._supports_load_balancer(): + if moe_load_balancer: + assert self._supports_load_balancer() + assert self.use_dp and self.parallel_size > 1, "Load Balancer should be only used with ADP and EP > 1" assert moe_load_balancer_config is not None top_k = self.routing_method.experts_per_token self.expert_size_per_partition = moe_load_balancer_config.num_local_slots @@ -282,8 +284,7 @@ def _init_load_balancer( self.initial_local_expert_ids) == self.expert_size_per_partition # Setup AllReduce for dynamic routing if needed - if (self.layer_load_balancer - and not self.layer_load_balancer.is_static_routing()): + if self._using_dynamic_load_balancer(): from tensorrt_llm.functional import AllReduceStrategy from ...distributed import AllReduce @@ -294,7 +295,7 @@ def _init_load_balancer( def _add_raw_shared_weights_for_unmap(self, weight_tensors: List[torch.Tensor]): - if self.layer_load_balancer: + if self._using_dynamic_load_balancer(): self.layer_load_balancer._add_raw_host_weight_for_unmap( weight_tensors) @@ -305,6 +306,12 @@ def _supports_load_balancer(self) -> bool: """ return False + def _using_dynamic_load_balancer(self) -> bool: + """Check if this MoE is using dynamic load balancer.""" + if self.layer_load_balancer: + return self.layer_load_balancer.is_dynamic_routing() + return False + def _get_load_balancer_aux_stream(self) -> Optional[torch.cuda.Stream]: """Get auxiliary stream for load balancer from aux_stream_dict. 
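Taken together, these helpers define a per-layer staging protocol. Below is a schematic sketch (not a verbatim implementation) of how a backend's forward pass is expected to drive them, following the first-call/last-call gating used in the Cutlass and TRTLLM-Gen changes in this series; the expert-computation step is a placeholder:

```python
import torch


def run_moe_layer_with_eplb(moe: torch.nn.Module, x: torch.Tensor,
                            router_logits: torch.Tensor) -> torch.Tensor:
    # A layer may be invoked repeat_count times per iteration on the same
    # weights; only the first/last call advance the load-balancer stages.
    is_first_call = moe.repeat_idx == 0
    is_last_call = moe.repeat_idx == moe.repeat_count - 1

    moe._load_balancer_start_wait_gpu_stage(is_first_call)
    token_selected_experts, token_final_scales = moe.routing_method.apply(
        router_logits)
    moe._load_balancer_done_wait_gpu_stage(is_first_call)

    # Record routing statistics, then map global expert IDs to local slot IDs.
    moe._load_balancer_update_statistic(token_selected_experts, is_first_call,
                                        is_last_call)
    token_selected_slots = moe._load_balancer_route(token_selected_experts,
                                                    moe.use_dp)

    # Placeholder: quantize x, dispatch (alltoall/allgather), run grouped GEMMs
    # over the local slots using token_selected_slots and token_final_scales.
    expert_output = ...

    moe._load_balancer_start_set_cpu_stage(is_last_call)
    # Backend-specific combine (alltoall combine / reduce-scatter) happens here
    # and can overlap with the load balancer's CPU-side bookkeeping.
    final_output = expert_output

    moe._load_balancer_done_set_cpu_stage(is_last_call)

    moe.repeat_idx = 0 if is_last_call else moe.repeat_idx + 1
    return final_output
```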
@@ -316,12 +323,12 @@ def _get_load_balancer_aux_stream(self) -> Optional[torch.cuda.Stream]: def _load_balancer_start_wait_gpu_stage(self, is_first_call: bool): """Start waiting for GPU stage in load balancer.""" - if self.layer_load_balancer and is_first_call: + if self._using_dynamic_load_balancer() and is_first_call: self.layer_load_balancer.start_wait_gpu_stage() def _load_balancer_done_wait_gpu_stage(self, is_first_call: bool): """Mark GPU wait stage as done in load balancer.""" - if self.layer_load_balancer and is_first_call: + if self._using_dynamic_load_balancer() and is_first_call: self.layer_load_balancer.done_wait_gpu_stage() def _load_balancer_update_statistic(self, @@ -338,7 +345,7 @@ def _load_balancer_update_statistic(self, is_last_call: Whether this is the last call for the same weights ignore_allreduce: Whether to ignore allreduce, if True, only update local statistics, need call _load_balancer_get_local_statistic_tensor to get the local statistic tensor and then do external allgather and then call _load_balancer_update_statistic_with_gathered_statistic to update the global statistics. MnnvlLatency supports this. """ - if self.layer_load_balancer: + if self._using_dynamic_load_balancer(): if ignore_allreduce: self.layer_load_balancer.update_local_statistic( token_selected_experts, @@ -362,32 +369,31 @@ def _load_balancer_route(self, token_selected_experts: torch.Tensor, def _load_balancer_start_set_cpu_stage(self, is_last_call: bool): """Start CPU stage in load balancer.""" - if self.layer_load_balancer and is_last_call: + if self._using_dynamic_load_balancer() and is_last_call: self.layer_load_balancer.start_set_cpu_stage() def _load_balancer_done_set_cpu_stage(self, is_last_call: bool): """Mark CPU stage as done in load balancer.""" - if self.layer_load_balancer and is_last_call: + if self._using_dynamic_load_balancer() and is_last_call: self.layer_load_balancer.done_set_cpu_stage() def _load_balancer_get_local_statistic_tensor(self): """Get local statistic tensor from load balancer.""" - if (self.layer_load_balancer - and not self.layer_load_balancer.is_static_routing()): + if self._using_dynamic_load_balancer(): return self.layer_load_balancer.get_local_statistic_tensor() return None def _load_balancer_update_statistic_with_gathered_statistic( self, gathered_statistic): """Update load balancer with gathered statistics.""" - if self.layer_load_balancer: + if self._using_dynamic_load_balancer(): self.layer_load_balancer.update_statistic_with_gathered_statistic( gathered_statistic) def register_parameter_weight_slot_fn(self, weight_name: str, local_slot_id: int): """Register parameter weight slot function for load balancer.""" - if not self.layer_load_balancer: + if not self._using_dynamic_load_balancer(): return assert hasattr( @@ -399,7 +405,7 @@ def register_parameter_weight_slot_fn(self, weight_name: str, def register_to_fix_weight_fn(self, weight_name: str): """Register weight fixing function for load balancer.""" - if not self.layer_load_balancer: + if not self._using_dynamic_load_balancer(): return assert hasattr( @@ -424,7 +430,7 @@ def register_to_fix_weight_fn(self, weight_name: str): def register_all_parameter_slot_and_to_fix_weight_fns( self, weight_and_tensor_dict: Dict[str, torch.Tensor]): """Register all parameter slot and weight fixing functions for load balancer.""" - if not self.layer_load_balancer: + if not self._using_dynamic_load_balancer(): return # Register weight functions for each local slot diff --git 
diff --git a/tensorrt_llm/_torch/modules/fused_moe/moe_load_balancer.py b/tensorrt_llm/_torch/modules/fused_moe/moe_load_balancer.py
index 20fe1d9a8be..907db7d79ab 100644
--- a/tensorrt_llm/_torch/modules/fused_moe/moe_load_balancer.py
+++ b/tensorrt_llm/_torch/modules/fused_moe/moe_load_balancer.py
@@ -459,9 +459,12 @@ def get_repeat_count(self):
     def is_static_routing(self):
         return not self.updates_enabled
 
-    def need_load_shared_weights(self):
+    def is_dynamic_routing(self):
         return self.updates_enabled
 
+    def need_load_shared_weights(self):
+        return self.is_dynamic_routing()
+
     def set_shared_memory_base_name(self, shared_memory_base_name):
         """
         Set the shared memory base name for the layer.
@@ -807,8 +810,9 @@ def route(self,
         Returns:
             A tensor of routed slot IDs
         """
-        assert self.func_called_count["done_wait_gpu_stage"] == 1
-        self.func_called_count["route"] += 1
+        if self.is_dynamic_routing():
+            assert self.func_called_count["done_wait_gpu_stage"] == 1
+            self.func_called_count["route"] += 1
         return torch.ops.trtllm.moe_load_balance_routing(
             token_selected_experts, offset_by_ep_rank,
             self.single_layer_load_balancer_ptr)
@@ -881,6 +885,9 @@ def is_static_routing(self):
         # if we don't update, then it is statistic routing.
         return self.layer_updates_per_iter == 0
 
+    def is_dynamic_routing(self):
+        return not self.is_static_routing()
+
     def _setup_mpi_comm(self):
         global_mpi_comm = tensorrt_llm.mpi_comm()
         shared_mpi_comm = global_mpi_comm.Split_type(
diff --git a/tensorrt_llm/_torch/modules/fused_moe/quantization.py b/tensorrt_llm/_torch/modules/fused_moe/quantization.py
index eea4a3d43b0..d7abaeaa071 100644
--- a/tensorrt_llm/_torch/modules/fused_moe/quantization.py
+++ b/tensorrt_llm/_torch/modules/fused_moe/quantization.py
@@ -157,30 +157,26 @@ class FusedMoEMethodBase(ABC):
     weight_alignment: int = 1
 
     @classmethod
-    def _online_eplb_not_supported(cls, module):
+    def need_load_shared_weights(cls, module):
         if hasattr(
                 module, "layer_load_balancer"
         ) and module.layer_load_balancer and module.layer_load_balancer.need_load_shared_weights(
         ):
+            return True
+        else:
+            return False
+
+    @classmethod
+    def _online_eplb_not_supported(cls, module):
+        if cls.need_load_shared_weights(module):
             raise NotImplementedError(
                 f'{cls.__name__} doesn\'t support online EPLB now')
 
     @classmethod
     def _online_eplb_not_verified(cls, module):
-        if hasattr(
-                module, "layer_load_balancer"
-        ) and module.layer_load_balancer and module.layer_load_balancer.need_load_shared_weights(
-        ):
+        if cls.need_load_shared_weights(module):
             logger.warning(f'{cls.__name__} online EPLB is not verified yet')
 
-    def need_load_shared_weights(self, module):
-        if hasattr(
-                module, "layer_load_balancer"
-        ) and module.layer_load_balancer and module.layer_load_balancer.need_load_shared_weights(
-        ):
-            return True
-        return False
-
     def create_weights(
         self,
         module: torch.nn.Module,

From eaa3347e478df5339fcde78edaeb117d9252ee26 Mon Sep 17 00:00:00 2001
From: Dongxu Yang <78518666+dongxuy04@users.noreply.github.com>
Date: Mon, 10 Nov 2025 13:44:09 +0800
Subject: [PATCH 11/13] enable EPLB for mnnvlthroughput

Signed-off-by: Dongxu Yang <78518666+dongxuy04@users.noreply.github.com>
---
 tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py | 6 +++---
 .../_torch/modules/fused_moe/fused_moe_trtllm_gen.py       | 2 +-
 2 files changed, 4 insertions(+), 4 deletions(-)

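The two hunks that follow size the alltoall workspace by `num_slots` instead of `num_experts`. Under EPLB the routed IDs are expert-slot IDs (see the `route` docstring above), and the slot count exceeds the logical expert count whenever redundant slots are configured, so sizing by experts would under-allocate. A back-of-the-envelope illustration with made-up numbers; the real values come from the load-balancer configuration:

    # Hypothetical EPLB setup, for illustration only.
    num_experts = 256            # logical experts defined by the checkpoint
    num_redundant_slots = 16     # extra slots EPLB can hand to hot experts
    ep_size = 8                  # expert-parallel ranks

    num_slots = num_experts + num_redundant_slots    # 272 possible slot IDs after routing
    num_local_slots = num_slots // ep_size           # 34 slots hosted on each rank

    # A per-expert workspace sized with num_experts (256) could not address
    # slot IDs 256..271, so the buffers must be sized with num_slots instead.
    assert num_slots > num_experts
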
diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py
index c9214287788..315c2443c0c 100755
--- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py
+++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py
@@ -156,7 +156,7 @@ def __init__(
                 mapping=self.mapping,
                 max_num_tokens=model_config.max_num_tokens,
                 top_k=self.routing_method.experts_per_token,
-                num_experts=self.num_experts,
+                num_experts=self.num_slots,
                 workspace_size_per_rank=workspace_mb * 1024 * 1024,
             )
         else:
@@ -702,7 +702,7 @@ def _reducescatter_or_allreduce(x_, idx):
         # Postpone reduce-scatter/all-reduce to the next iteration to achieve better overlap
         for idx_chunk, (x, router_logits) in enumerate(
                 zip(x_list, router_logits_list)):
-            if not (self.moe_alltoall_backend == "mnnvllatency"):
+            if not (self.alltoall_method_type == AlltoallMethodType.MNNVL):
                 if idx_chunk % 2 == 0:
                     with torch.cuda.stream(self.aux_stream):
                         outputs = _forward_chunk(x, router_logits,
@@ -720,7 +720,7 @@ def _reducescatter_or_allreduce(x_, idx):
 
             outputs_list.append(outputs)
 
-        if not (self.moe_alltoall_backend == "mnnvllatency"):
+        if not (self.alltoall_method_type == AlltoallMethodType.MNNVL):
             if num_chunks % 2 == 0:
                 outputs_list[-1] = _reducescatter_or_allreduce(
                     outputs_list[-1], -1)
diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py
index d41c2a7a74a..044d461789f 100644
--- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py
+++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py
@@ -133,7 +133,7 @@ def __init__(
                 mapping=self.mapping,
                 max_num_tokens=model_config.max_num_tokens,
                 top_k=self.routing_method.experts_per_token,
-                num_experts=self.num_experts,
+                num_experts=self.num_slots,
                 workspace_size_per_rank=workspace_mb * 1024 * 1024,
             )
         else:

From 51d9e38ff4d983bb0079aa8b20b3c113fb31079b Mon Sep 17 00:00:00 2001
From: Dongxu Yang <78518666+dongxuy04@users.noreply.github.com>
Date: Mon, 10 Nov 2025 16:59:32 +0800
Subject: [PATCH 12/13] address review issues

Signed-off-by: Dongxu Yang <78518666+dongxuy04@users.noreply.github.com>
---
 tensorrt_llm/_torch/modules/fused_moe/moe_load_balancer.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tensorrt_llm/_torch/modules/fused_moe/moe_load_balancer.py b/tensorrt_llm/_torch/modules/fused_moe/moe_load_balancer.py
index 907db7d79ab..82289540980 100644
--- a/tensorrt_llm/_torch/modules/fused_moe/moe_load_balancer.py
+++ b/tensorrt_llm/_torch/modules/fused_moe/moe_load_balancer.py
@@ -243,7 +243,7 @@ def align_size(size: int):
     def add_raw_host_weight_for_unmap(self,
                                       raw_weight_tensors: List[torch.Tensor]):
         """
-        Add a raw weight (mmapped Tensor) to HostMoeTensorSharer for later madvise to save host memory.
+        Add a raw weight (mmapped Tensor) to HostMoeTensorSharer for later `madvise` to save host memory.
 
         Args:
             raw_weight_tensors: A list of raw weight tensors
@@ -505,7 +505,7 @@ def register_weight_slot(self, local_slot_id: int, name: str,
     def _add_raw_host_weight_for_unmap(self,
                                        raw_weight_tensors: List[torch.Tensor]):
         """
-        Add a raw weight (mmapped Tensor) to LoadBalancer for later madvise to save host memory.
+        Add a raw weight (mmapped Tensor) to LoadBalancer for later `madvise` to save host memory.
 
         Args:
             raw_weight_tensors: A list of raw weight tensors

From 83782760e8e658d3e74d0f7315260221cbece487 Mon Sep 17 00:00:00 2001
From: Dongxu Yang <78518666+dongxuy04@users.noreply.github.com>
Date: Tue, 11 Nov 2025 18:51:53 +0800
Subject: [PATCH 13/13] also add moe_alltoall_backend for WideEP backend

Signed-off-by: Dongxu Yang <78518666+dongxuy04@users.noreply.github.com>
---
 tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py
index f77dbc01113..898a867b286 100755
--- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py
+++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py
@@ -1,4 +1,5 @@
 import os
+from functools import cached_property
 from typing import Dict, List, Optional, Tuple, Union
 
 import torch
@@ -245,6 +246,12 @@ def enable_alltoall(self):
         """
         return self.alltoall_method_type != AlltoallMethodType.NotEnabled
 
+    @cached_property
+    def moe_alltoall_backend(self):
+        # "mnnvllatency" (default) or "mnnvlthroughput"
+        return os.environ.get("TRTLLM_MOE_ALLTOALL_BACKEND",
+                              "mnnvllatency").strip().lower()
+
     def calculate_num_chunks(self, all_rank_num_tokens: List[int]) -> int:
         num_rows = sum(all_rank_num_tokens)
         return (num_rows + self.moe_max_num_tokens -