diff --git a/megatron/core/extensions/transformer_engine.py b/megatron/core/extensions/transformer_engine.py index 32845980c14..a7259570077 100644 --- a/megatron/core/extensions/transformer_engine.py +++ b/megatron/core/extensions/transformer_engine.py @@ -535,7 +535,7 @@ def __init__( ) if is_te_min_version("0.8.0"): - if self.config.tp_comm_overlap: + if self.config.tp_comm_overlap and parallel_mode != "duplicated": if is_te_min_version("1.5.0"): # Use old overlap flags if they were supplied instead extra_kwargs["ub_overlap_ag"] = ( diff --git a/megatron/core/transformer/mlp.py b/megatron/core/transformer/mlp.py index 9602beb2f71..821177e1b06 100644 --- a/megatron/core/transformer/mlp.py +++ b/megatron/core/transformer/mlp.py @@ -104,9 +104,13 @@ def __init__( if self.config.gated_linear_unit: ffn_hidden_size *= 2 + # Use moe_latent_size only for routed experts. 'is_expert' is false for + # shared_experts. + use_latent_size = (self.config.moe_latent_size is not None) and is_expert + self.linear_fc1 = build_module( submodules.linear_fc1, - self.input_size, + self.input_size if not use_latent_size else self.config.moe_latent_size, ffn_hidden_size, config=self.config, init_method=self.config.init_method, @@ -126,7 +130,7 @@ def __init__( self.linear_fc2 = build_module( submodules.linear_fc2, self.config.ffn_hidden_size, - self.config.hidden_size, + self.config.hidden_size if not use_latent_size else self.config.moe_latent_size, config=self.config, init_method=self.config.output_layer_init_method, bias=self.config.add_bias_linear, diff --git a/megatron/core/transformer/moe/experts.py b/megatron/core/transformer/moe/experts.py index 68a3d53d2be..9ea26e3e2ee 100644 --- a/megatron/core/transformer/moe/experts.py +++ b/megatron/core/transformer/moe/experts.py @@ -118,6 +118,9 @@ def __init__( assert ( config.add_bias_linear == False ), "bias not supported in Grouped GEMM yet, please set '--disable-bias-linear' instead." + assert ( + config.moe_latent_size is None + ), "MoE latent projection not supported in GroupedMLP yet." self.expert_parallel = config.expert_model_parallel_size > 1 if self.config.gated_linear_unit: @@ -778,7 +781,7 @@ def __init__( self.linear_fc1 = build_module( submodules.linear_fc1, self.num_local_experts, - self.input_size, + self.input_size if self.config.moe_latent_size is None else self.config.moe_latent_size, ffn_hidden_size, config=self.config, init_method=self.config.init_method, @@ -799,7 +802,11 @@ def __init__( submodules.linear_fc2, self.num_local_experts, self.config.moe_ffn_hidden_size, - self.config.hidden_size, + ( + self.config.hidden_size + if self.config.moe_latent_size is None + else self.config.moe_latent_size + ), config=self.config, init_method=self.config.output_layer_init_method, bias=self.config.add_bias_linear, diff --git a/megatron/core/transformer/moe/moe_layer.py b/megatron/core/transformer/moe/moe_layer.py index e3de8220a54..fe7e224255d 100644 --- a/megatron/core/transformer/moe/moe_layer.py +++ b/megatron/core/transformer/moe/moe_layer.py @@ -23,7 +23,7 @@ try: import transformer_engine as te # pylint: disable=unused-import - from megatron.core.extensions.transformer_engine import te_checkpoint + from megatron.core.extensions.transformer_engine import TELinear, te_checkpoint HAVE_TE = True except ImportError: @@ -120,9 +120,35 @@ def __init__( and "shared_experts" in config.recompute_modules ) - # Initialize router + # Initialize router. self.router = TopKRouter(config=self.config, pg_collection=pg_collection) + # Initialize latent projections. + if self.config.moe_latent_size: + assert HAVE_TE, "TransformerEngine is required for MoE latent projections." + self.fc1_latent_proj = TELinear( + self.config.hidden_size, + self.config.moe_latent_size, + parallel_mode="duplicated", + config=self.config, + init_method=self.config.init_method, + bias=self.config.add_bias_linear, + skip_bias_add=False, + skip_weight_param_allocation=False, + is_expert=False, + ) + self.fc2_latent_proj = TELinear( + self.config.moe_latent_size, + self.config.hidden_size, + parallel_mode="duplicated", + config=self.config, + init_method=self.config.output_layer_init_method, + bias=self.config.add_bias_linear, + skip_bias_add=False, + skip_weight_param_allocation=False, + is_expert=False, + ) + # Initialize token dispatcher if config.moe_token_dispatcher_type == "allgather": self.token_dispatcher = MoEAllGatherTokenDispatcher( @@ -176,6 +202,12 @@ def router_and_preprocess(self, hidden_states: torch.Tensor): """ residual = hidden_states probs, routing_map = self.router(hidden_states) + # Project the hidden_states from hidden dimension down to latent dimenion. + if self.config.moe_latent_size: + assert ( + not self.shared_expert_overlap + ), "Shared expert overlap not supported when MoE latent projections are used." + hidden_states, _ = self.fc1_latent_proj(hidden_states) hidden_states, probs = self.token_dispatcher.dispatch_preprocess( hidden_states, routing_map, probs ) @@ -243,6 +275,10 @@ def combine(self, output: torch.Tensor, shared_expert_output: Optional[torch.Ten """ output = self.token_dispatcher.token_combine(output) output = self.token_dispatcher.combine_postprocess(output) + # Project the output back from latent dimension to hidden dimension after combine + # in latent dimension. + if self.config.moe_latent_size: + output, _ = self.fc2_latent_proj(output) if shared_expert_output is not None: output = output + shared_expert_output return output @@ -274,7 +310,9 @@ def custom_forward(hidden_states): hidden_states, probs, residual = self.router_and_preprocess(hidden_states) dispatched_input, probs = self.dispatch(hidden_states, probs) output, mlp_bias = self.routed_experts_compute(dispatched_input, probs, residual) + assert mlp_bias is None, f"mlp_bias is not supported for {type(self.token_dispatcher)}" output = self.combine(output, shared_expert_output) + return output, mlp_bias if self.moe_layer_recompute: diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index bd384d1ad93..ffc82bc7980 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -209,13 +209,6 @@ class TransformerConfig(ModelParallelConfig): A list of integers: Defines a custom pattern where 1 means skip RoPE and 0 means apply RoPE. For example, [0,1,1,0] means: apply RoPE, skip RoPE, skip RoPE, apply RoPE.""" - moe_deepep_num_sms: int = 20 - """Number of SMs to use for DeepEP.""" - - moe_hybridep_num_sms: int = 16 - """Number of SMs to use for HybridEP. In pure NVL scenarios, - 16 SMs can generally achieve good bandwidth.""" - #################### # initialization #################### @@ -609,6 +602,16 @@ class TransformerConfig(ModelParallelConfig): moe_apply_probs_on_input: bool = False """Apply probs on input of experts instead of applying after activation and glu.""" + moe_latent_size: Optional[int] = None + """Latent projection dimension for MoE. If None, MoE latent projections are not used.""" + + moe_deepep_num_sms: int = 20 + """Number of SMs to use for DeepEP.""" + + moe_hybridep_num_sms: int = 16 + """Number of SMs to use for HybridEP. In pure NVL scenarios, + 16 SMs can generally achieve good bandwidth.""" + ################## # Context Parallel ################## diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index f4fe8f1159b..c51d66dd68a 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -1249,6 +1249,13 @@ def validate_args(args, defaults={}): args.recompute_granularity != 'full' ), 'recompute_granularity must not be full when CUDA Graphs are enabled.' + # MoE latent projections + if args.moe_latent_size is not None: + assert args.moe_latent_size > 0, "MoE latent projection dimension has to be greater than zero." + assert args.num_experts is not None, "MoE latent projections are applicable only for MoE models." + assert not args.use_legacy_models, "MoE latent projections are only supported for mcore models." + assert not args.moe_use_legacy_grouped_gemm, "MoE latent projection is not supported yet with legacy grouped GEMM." + if args.tiktoken_special_tokens and not args.tokenizer_special_tokens: warn_rank_0( "--tiktoken-special-tokens argument is deprecated and will be removed soon. " @@ -1355,6 +1362,8 @@ def core_transformer_config_from_args(args, config_class=None): kw_args['use_kitchen'] = True kw_args['quant_recipe'] = kitchen_quantization_recipe_config(args.kitchen_recipe_number) + kw_args['moe_latent_size'] = args.moe_latent_size + if args.te_precision_config_file: assert not 'quant_recipe' in kw_args, "Quantization recipe already configured." # TODO(kwyss): Prohibit fp8_params or fp4_params with this flexibility @@ -1743,6 +1752,8 @@ def _add_network_size_args(parser): 'We compute the average of the MTP losses across all depths, ' 'and multiply it the scaling factor to obtain the overall MTP loss, ' 'which serves as an additional training objective.') + group.add_argument('--moe-latent-size', type=int, default=None, + help='Latent projection dimension for MoE. If None, MoE latent projections are not used.') return parser diff --git a/megatron/training/checkpointing.py b/megatron/training/checkpointing.py index eb23e7cc092..853f18f8358 100644 --- a/megatron/training/checkpointing.py +++ b/megatron/training/checkpointing.py @@ -1341,6 +1341,9 @@ def _set_arg(arg_name, old_arg_name=None, force=False): _set_arg('heterogeneous_layers_config_path', force=True) _set_arg('heterogeneous_layers_config_encoded_json', force=True) + # MoE latent projection. + _set_arg('moe_latent_size', force=True) + # Tokenizer args. _set_arg('tokenizer_type', force=True) # Using checkpoint version might not always be safe (e.g., if running on different cluster). diff --git a/megatron/training/training.py b/megatron/training/training.py index a81177363d2..d4f5d468f5a 100644 --- a/megatron/training/training.py +++ b/megatron/training/training.py @@ -180,11 +180,19 @@ def mlp_layer_flops(batch_size, seq_len, hidden_size, expansion=4.0, swiglu=Fals return 4 * expansion * scale_factor * batch_size * seq_len * hidden_size**2 def moe_layer_flops(batch_size, seq_len, hidden_size, moe_ffn_hidden_size, - shared_expert_ffn_hidden_size, num_experts_routed_to, swiglu=False): + shared_expert_ffn_hidden_size, num_experts_routed_to, + moe_latent_size=None, swiglu=False): """Calculate FLOPs for an MoE layer.""" scale_factor = 3.0 / 2.0 if swiglu else 1.0 - routed_flops = (4 * batch_size * seq_len * hidden_size * - moe_ffn_hidden_size * num_experts_routed_to * scale_factor) + if moe_latent_size is None: + routed_flops = (4 * batch_size * seq_len * hidden_size * + moe_ffn_hidden_size * num_experts_routed_to * scale_factor) + else: + # Routed experts run on moe_latent_size. + routed_flops = (4 * batch_size * seq_len * moe_latent_size * + moe_ffn_hidden_size * num_experts_routed_to * scale_factor) + # Up proj and down proj. + routed_flops += (4 * batch_size * seq_len * hidden_size * moe_latent_size) shared_flops = 4 * batch_size * seq_len * hidden_size * shared_expert_ffn_hidden_size * scale_factor return routed_flops + shared_flops @@ -232,6 +240,7 @@ def hybrid_flops(batch_size, seq_len, hidden_size, num_attn_heads=32, gqa=True, gqa_groups=8, kv_channels=None, mlp_expansion=4.0, swiglu=False, + moe_latent_size=None, moe_ffn_hidden_size=2048, shared_expert_ffn_hidden_size=2048, num_experts_routed_to=1, vocab_size=256000): """Calculate total FLOPs for the hybrid model.""" @@ -244,7 +253,8 @@ def hybrid_flops(batch_size, seq_len, hidden_size, mamba_state_dim, mamba_head_dim, mamba_num_groups, mamba_num_heads) + num_moe_layers * moe_layer_flops(batch_size, seq_len, hidden_size, moe_ffn_hidden_size, - shared_expert_ffn_hidden_size, num_experts_routed_to, swiglu) + + shared_expert_ffn_hidden_size, num_experts_routed_to, + moe_latent_size, swiglu) + (2 * batch_size * seq_len * hidden_size * vocab_size) # logits computation ) return flops_fwd * 3 @@ -449,6 +459,7 @@ def transformer_flops(): kv_channels=args.kv_channels, mlp_expansion=args.ffn_hidden_size / args.hidden_size, swiglu=args.swiglu, + moe_latent_size=args.moe_latent_size, moe_ffn_hidden_size=(args.moe_ffn_hidden_size if args.moe_ffn_hidden_size is not None else args.ffn_hidden_size), shared_expert_ffn_hidden_size=(0 if args.moe_shared_expert_intermediate_size is None