2 changes: 1 addition & 1 deletion megatron/core/extensions/transformer_engine.py
@@ -535,7 +535,7 @@ def __init__(
)

if is_te_min_version("0.8.0"):
if self.config.tp_comm_overlap:
if self.config.tp_comm_overlap and parallel_mode != "duplicated":
if is_te_min_version("1.5.0"):
# Use old overlap flags if they were supplied instead
extra_kwargs["ub_overlap_ag"] = (
8 changes: 6 additions & 2 deletions megatron/core/transformer/mlp.py
@@ -104,9 +104,13 @@ def __init__(
if self.config.gated_linear_unit:
ffn_hidden_size *= 2

# Use moe_latent_size only for routed experts. 'is_expert' is false for
# shared_experts.
use_latent_size = (self.config.moe_latent_size is not None) and is_expert

self.linear_fc1 = build_module(
submodules.linear_fc1,
self.input_size,
self.input_size if not use_latent_size else self.config.moe_latent_size,
ffn_hidden_size,
config=self.config,
init_method=self.config.init_method,
@@ -126,7 +130,7 @@ def __init__(
self.linear_fc2 = build_module(
submodules.linear_fc2,
self.config.ffn_hidden_size,
self.config.hidden_size,
self.config.hidden_size if not use_latent_size else self.config.moe_latent_size,
config=self.config,
init_method=self.config.output_layer_init_method,
bias=self.config.add_bias_linear,
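As a reading aid for the mlp.py change above, here is a minimal, hypothetical sketch of the dimension selection it implements; `_Cfg` and `mlp_dims` are made-up names and the sizes are arbitrary, so this is not the MCore code itself.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class _Cfg:
    """Hypothetical stand-in for the relevant TransformerConfig fields."""
    hidden_size: int = 4096
    ffn_hidden_size: int = 1024
    moe_latent_size: Optional[int] = 512


def mlp_dims(cfg: _Cfg, is_expert: bool):
    # Only routed experts (is_expert=True) see the latent dimension;
    # shared experts keep the full hidden size.
    use_latent = cfg.moe_latent_size is not None and is_expert
    fc1_in = cfg.moe_latent_size if use_latent else cfg.hidden_size
    fc2_out = cfg.moe_latent_size if use_latent else cfg.hidden_size
    return fc1_in, fc2_out


assert mlp_dims(_Cfg(), is_expert=True) == (512, 512)     # routed expert runs in latent space
assert mlp_dims(_Cfg(), is_expert=False) == (4096, 4096)  # shared expert stays in hidden space
```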
11 changes: 9 additions & 2 deletions megatron/core/transformer/moe/experts.py
@@ -118,6 +118,9 @@ def __init__(
assert (
config.add_bias_linear == False
), "bias not supported in Grouped GEMM yet, please set '--disable-bias-linear' instead."
assert (
config.moe_latent_size is None
), "MoE latent projection not supported in GroupedMLP yet."

self.expert_parallel = config.expert_model_parallel_size > 1
if self.config.gated_linear_unit:
@@ -778,7 +781,7 @@ def __init__(
self.linear_fc1 = build_module(
submodules.linear_fc1,
self.num_local_experts,
self.input_size,
self.input_size if self.config.moe_latent_size is None else self.config.moe_latent_size,
ffn_hidden_size,
config=self.config,
init_method=self.config.init_method,
@@ -799,7 +802,11 @@
submodules.linear_fc2,
self.num_local_experts,
self.config.moe_ffn_hidden_size,
self.config.hidden_size,
(
self.config.hidden_size
if self.config.moe_latent_size is None
else self.config.moe_latent_size
),
config=self.config,
init_method=self.config.output_layer_init_method,
bias=self.config.add_bias_linear,
42 changes: 40 additions & 2 deletions megatron/core/transformer/moe/moe_layer.py
@@ -23,7 +23,7 @@
try:
import transformer_engine as te # pylint: disable=unused-import

from megatron.core.extensions.transformer_engine import te_checkpoint
from megatron.core.extensions.transformer_engine import TELinear, te_checkpoint

HAVE_TE = True
except ImportError:
@@ -120,9 +120,35 @@ def __init__(
and "shared_experts" in config.recompute_modules
)

# Initialize router
# Initialize router.
self.router = TopKRouter(config=self.config, pg_collection=pg_collection)

# Initialize latent projections.
if self.config.moe_latent_size:
assert HAVE_TE, "TransformerEngine is required for MoE latent projections."
self.fc1_latent_proj = TELinear(
self.config.hidden_size,
self.config.moe_latent_size,
parallel_mode="duplicated",
config=self.config,
init_method=self.config.init_method,
bias=self.config.add_bias_linear,
skip_bias_add=False,
skip_weight_param_allocation=False,
is_expert=False,
)
self.fc2_latent_proj = TELinear(
self.config.moe_latent_size,
self.config.hidden_size,
parallel_mode="duplicated",
config=self.config,
init_method=self.config.output_layer_init_method,
bias=self.config.add_bias_linear,
skip_bias_add=False,
skip_weight_param_allocation=False,
is_expert=False,
)

# Initialize token dispatcher
if config.moe_token_dispatcher_type == "allgather":
self.token_dispatcher = MoEAllGatherTokenDispatcher(
Expand Down Expand Up @@ -176,6 +202,12 @@ def router_and_preprocess(self, hidden_states: torch.Tensor):
"""
residual = hidden_states
probs, routing_map = self.router(hidden_states)
# Project the hidden_states from hidden dimension down to latent dimension.
if self.config.moe_latent_size:
assert (
not self.shared_expert_overlap
), "Shared expert overlap not supported when MoE latent projections are used."
hidden_states, _ = self.fc1_latent_proj(hidden_states)
hidden_states, probs = self.token_dispatcher.dispatch_preprocess(
hidden_states, routing_map, probs
)
@@ -243,6 +275,10 @@ def combine(self, output: torch.Tensor, shared_expert_output: Optional[torch.Ten
"""
output = self.token_dispatcher.token_combine(output)
output = self.token_dispatcher.combine_postprocess(output)
# Project the output back from the latent dimension to the hidden dimension after the
# combine, which runs in the latent dimension.
if self.config.moe_latent_size:
output, _ = self.fc2_latent_proj(output)
if shared_expert_output is not None:
output = output + shared_expert_output
return output
@@ -274,7 +310,9 @@ def custom_forward(hidden_states):
hidden_states, probs, residual = self.router_and_preprocess(hidden_states)
dispatched_input, probs = self.dispatch(hidden_states, probs)
output, mlp_bias = self.routed_experts_compute(dispatched_input, probs, residual)
assert mlp_bias is None, f"mlp_bias is not supported for {type(self.token_dispatcher)}"
output = self.combine(output, shared_expert_output)
Contributor:

output will be in latent dimension, while shared_expert_output will be in hidden dimension here. We may have to move self.fc2_latent_proj inside self.combine before the addition of output and shared_expert_output.

Contributor (Author):

Good point, moved.


return output, mlp_bias

if self.moe_layer_recompute:
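To make the review discussion above concrete, here is a shape-only sketch, under assumed sizes, of where the two latent projections sit relative to dispatch, combine, and the shared-expert addition; `nn.Linear` stands in for `TELinear`, and the dispatcher and routed experts are elided.

```python
import torch
import torch.nn as nn

hidden, latent, tokens = 4096, 512, 8          # illustrative sizes
fc1_latent_proj = nn.Linear(hidden, latent)    # stand-in for TELinear(parallel_mode="duplicated")
fc2_latent_proj = nn.Linear(latent, hidden)

x = torch.randn(tokens, hidden)
shared_expert_output = torch.randn(tokens, hidden)  # shared experts stay in the hidden dim

h = fc1_latent_proj(x)                   # hidden -> latent, before dispatch_preprocess
# ... dispatch, routed-expert MLPs, and token combine all run in the latent dim ...
combined = h                             # placeholder for the combined expert output
out = fc2_latent_proj(combined)          # latent -> hidden, inside combine()
out = out + shared_expert_output         # both operands are now in the hidden dim
assert out.shape == (tokens, hidden)
```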
17 changes: 10 additions & 7 deletions megatron/core/transformer/transformer_config.py
@@ -209,13 +209,6 @@ class TransformerConfig(ModelParallelConfig):
A list of integers: Defines a custom pattern where 1 means skip RoPE and 0 means apply RoPE.
For example, [0,1,1,0] means: apply RoPE, skip RoPE, skip RoPE, apply RoPE."""

moe_deepep_num_sms: int = 20
"""Number of SMs to use for DeepEP."""

moe_hybridep_num_sms: int = 16
"""Number of SMs to use for HybridEP. In pure NVL scenarios,
16 SMs can generally achieve good bandwidth."""

####################
# initialization
####################
@@ -609,6 +602,16 @@ class TransformerConfig(ModelParallelConfig):
moe_apply_probs_on_input: bool = False
"""Apply probs on input of experts instead of applying after activation and glu."""

moe_latent_size: Optional[int] = None
"""Latent projection dimension for MoE. If None, MoE latent projections are not used."""

moe_deepep_num_sms: int = 20
"""Number of SMs to use for DeepEP."""

moe_hybridep_num_sms: int = 16
"""Number of SMs to use for HybridEP. In pure NVL scenarios,
16 SMs can generally achieve good bandwidth."""

##################
# Context Parallel
##################
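A hypothetical usage sketch of the new field follows; apart from `moe_latent_size`, the field names mirror the existing config, but the values are arbitrary and this is illustrative rather than a tested configuration.

```python
from megatron.core.transformer.transformer_config import TransformerConfig

config = TransformerConfig(
    num_layers=4,
    hidden_size=4096,
    num_attention_heads=32,
    num_moe_experts=8,
    moe_ffn_hidden_size=1024,
    moe_latent_size=512,    # routed experts now run on a 512-dim latent space
    add_bias_linear=False,  # keeps the sketch consistent with the GroupedMLP restriction
)
```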
11 changes: 11 additions & 0 deletions megatron/training/arguments.py
@@ -1249,6 +1249,13 @@ def validate_args(args, defaults={}):
args.recompute_granularity != 'full'
), 'recompute_granularity must not be full when CUDA Graphs are enabled.'

# MoE latent projections
if args.moe_latent_size is not None:
assert args.moe_latent_size > 0, "MoE latent projection dimension has to be greater than zero."
assert args.num_experts is not None, "MoE latent projections are applicable only for MoE models."
assert not args.use_legacy_models, "MoE latent projections are only supported for mcore models."
assert not args.moe_use_legacy_grouped_gemm, "MoE latent projection is not supported yet with legacy grouped GEMM."

if args.tiktoken_special_tokens and not args.tokenizer_special_tokens:
warn_rank_0(
"--tiktoken-special-tokens argument is deprecated and will be removed soon. "
@@ -1355,6 +1362,8 @@ def core_transformer_config_from_args(args, config_class=None):
kw_args['use_kitchen'] = True
kw_args['quant_recipe'] = kitchen_quantization_recipe_config(args.kitchen_recipe_number)

kw_args['moe_latent_size'] = args.moe_latent_size

if args.te_precision_config_file:
assert not 'quant_recipe' in kw_args, "Quantization recipe already configured."
# TODO(kwyss): Prohibit fp8_params or fp4_params with this flexibility
@@ -1743,6 +1752,8 @@ def _add_network_size_args(parser):
'We compute the average of the MTP losses across all depths, '
'and multiply it the scaling factor to obtain the overall MTP loss, '
'which serves as an additional training objective.')
group.add_argument('--moe-latent-size', type=int, default=None,
help='Latent projection dimension for MoE. If None, MoE latent projections are not used.')
return parser


3 changes: 3 additions & 0 deletions megatron/training/checkpointing.py
@@ -1341,6 +1341,9 @@ def _set_arg(arg_name, old_arg_name=None, force=False):
_set_arg('heterogeneous_layers_config_path', force=True)
_set_arg('heterogeneous_layers_config_encoded_json', force=True)

# MoE latent projection.
_set_arg('moe_latent_size', force=True)

# Tokenizer args.
_set_arg('tokenizer_type', force=True)
# Using checkpoint version might not always be safe (e.g., if running on different cluster).
19 changes: 15 additions & 4 deletions megatron/training/training.py
@@ -180,11 +180,19 @@ def mlp_layer_flops(batch_size, seq_len, hidden_size, expansion=4.0, swiglu=Fals
return 4 * expansion * scale_factor * batch_size * seq_len * hidden_size**2

def moe_layer_flops(batch_size, seq_len, hidden_size, moe_ffn_hidden_size,
shared_expert_ffn_hidden_size, num_experts_routed_to, swiglu=False):
shared_expert_ffn_hidden_size, num_experts_routed_to,
moe_latent_size=None, swiglu=False):
"""Calculate FLOPs for an MoE layer."""
scale_factor = 3.0 / 2.0 if swiglu else 1.0
routed_flops = (4 * batch_size * seq_len * hidden_size *
moe_ffn_hidden_size * num_experts_routed_to * scale_factor)
if moe_latent_size is None:
routed_flops = (4 * batch_size * seq_len * hidden_size *
moe_ffn_hidden_size * num_experts_routed_to * scale_factor)
else:
# Routed experts run on moe_latent_size.
routed_flops = (4 * batch_size * seq_len * moe_latent_size *
moe_ffn_hidden_size * num_experts_routed_to * scale_factor)
# Up proj and down proj.
routed_flops += (4 * batch_size * seq_len * hidden_size * moe_latent_size)
shared_flops = 4 * batch_size * seq_len * hidden_size * shared_expert_ffn_hidden_size * scale_factor
return routed_flops + shared_flops

@@ -232,6 +240,7 @@ def hybrid_flops(batch_size, seq_len, hidden_size,
num_attn_heads=32, gqa=True,
gqa_groups=8, kv_channels=None,
mlp_expansion=4.0, swiglu=False,
moe_latent_size=None,
moe_ffn_hidden_size=2048, shared_expert_ffn_hidden_size=2048, num_experts_routed_to=1,
vocab_size=256000):
"""Calculate total FLOPs for the hybrid model."""
@@ -244,7 +253,8 @@
mamba_state_dim, mamba_head_dim,
mamba_num_groups, mamba_num_heads) +
num_moe_layers * moe_layer_flops(batch_size, seq_len, hidden_size, moe_ffn_hidden_size,
shared_expert_ffn_hidden_size, num_experts_routed_to, swiglu) +
shared_expert_ffn_hidden_size, num_experts_routed_to,
moe_latent_size, swiglu) +
(2 * batch_size * seq_len * hidden_size * vocab_size) # logits computation
)
return flops_fwd * 3
@@ -449,6 +459,7 @@ def transformer_flops():
kv_channels=args.kv_channels,
mlp_expansion=args.ffn_hidden_size / args.hidden_size,
swiglu=args.swiglu,
moe_latent_size=args.moe_latent_size,
moe_ffn_hidden_size=(args.moe_ffn_hidden_size if args.moe_ffn_hidden_size is not None
else args.ffn_hidden_size),
shared_expert_ffn_hidden_size=(0 if args.moe_shared_expert_intermediate_size is None
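A small worked example of the updated moe_layer_flops accounting, with illustrative numbers: when moe_latent_size is set, the routed-expert GEMMs are counted against the latent size and the two latent projections are added on top.

```python
batch, seq, hidden, latent = 1, 4096, 4096, 512
moe_ffn, topk, swiglu = 1024, 8, True
scale = 3.0 / 2.0 if swiglu else 1.0

routed = 4 * batch * seq * latent * moe_ffn * topk * scale    # expert GEMMs in latent space
routed += 4 * batch * seq * hidden * latent                   # fc1/fc2 latent projections
baseline = 4 * batch * seq * hidden * moe_ffn * topk * scale  # old formula, no latent size
print(f"with latent: {routed / 1e12:.2f} TFLOPs vs. baseline {baseline / 1e12:.2f} TFLOPs")
```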