Changes to support latent MoEs #2296

Merged
pablo-garay merged 9 commits into NVIDIA:main from deepakn94:dnarayanan/latent_moe on Dec 8, 2025

Conversation

@deepakn94 (Contributor)
No description provided.

@deepakn94 deepakn94 requested review from a team as code owners November 19, 2025 02:16
@deepakn94 deepakn94 added this to the Core 0.16 milestone Nov 19, 2025
@deepakn94 deepakn94 self-assigned this Nov 19, 2025
@deepakn94 deepakn94 force-pushed the dnarayanan/latent_moe branch from d088236 to b0a2d8c on November 19, 2025 02:32
@yanring (Contributor) commented Nov 24, 2025

Thanks for the work. Could you also add unit tests and integration tests covering this feature combined with EP/TP?

@yanring (Contributor) left a review comment:

Overall LGTM, left a few comments.

16 SMs can generally achieve good bandwidth."""

moe_latent_size: Optional[int] = None
"""Latent projection dimension for MoE. If None, MoE latent projections are not used."""
Contributor:

Could you elaborate on it a bit here?

Contributor:

Yeah, do you have a reference for latent MoEs?
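For readers unfamiliar with the idea being discussed: the new `moe_latent_size` option projects hidden states down to a smaller latent dimension before tokens are dispatched to experts, and projects expert outputs back up afterwards. The sketch below is a minimal, hypothetical illustration of that shape flow (it is not the PR's implementation; `LatentMoESketch` and its expert argument are made up for illustration, and real MoE dispatch/combine is omitted):

```python
import torch
import torch.nn as nn

class LatentMoESketch(nn.Module):
    """Illustrative latent-MoE wrapper: hidden states are down-projected to a
    smaller latent dimension before token dispatch, the experts operate in
    latent space, and outputs are up-projected back to the hidden dimension.
    When latent_size < hidden_size, this shrinks the tensors moved by the
    expert-parallel all-to-all."""

    def __init__(self, hidden_size: int, latent_size: int, experts: nn.Module):
        super().__init__()
        self.fc1_latent_proj = nn.Linear(hidden_size, latent_size, bias=False)
        self.experts = experts  # expert MLPs operating on latent_size inputs
        self.fc2_latent_proj = nn.Linear(latent_size, hidden_size, bias=False)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        latent = self.fc1_latent_proj(hidden_states)  # [tokens, latent_size]
        expert_out = self.experts(latent)             # dispatch/combine omitted
        return self.fc2_latent_proj(expert_out)       # [tokens, hidden_size]
```

With `hidden_size=8` and `latent_size=4`, a `[tokens, 8]` input round-trips through a `[tokens, 4]` latent tensor and comes back as `[tokens, 8]`.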

dispatched_input, probs = self.dispatch(hidden_states, probs)
output, mlp_bias = self.routed_experts_compute(dispatched_input, probs, residual)

if self.config.moe_latent_size and mlp_bias is not None:
Contributor:

Please document the change here to explain the bias handling update.
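To make the requested documentation concrete: because the expert MLPs are built with `skip_bias_add=True`, `routed_experts_compute` returns the bias separately, and when latent projections are enabled that bias is still in the latent dimension. A hedged sketch of the resulting handling, written as a self-contained helper (the function name and argument layout are hypothetical, not the merged code):

```python
import torch

def combine_expert_output(output, mlp_bias, fc2_latent_proj, use_latent: bool):
    """Hypothetical helper mirroring the diff's bias handling. With latent
    projections enabled, the separately returned expert bias lives in the
    latent dimension, so it must be folded into the output *before* the
    up-projection back to the hidden dimension; afterwards there is no bias
    left to hand to downstream fused bias-add logic."""
    if use_latent and mlp_bias is not None:
        output = output + mlp_bias  # add bias while still in latent space
        mlp_bias = None             # consumed; nothing left to return
    if use_latent:
        output = fc2_latent_proj(output)  # latent -> hidden
    return output, mlp_bias
```

This is only one plausible reading of the condition in the diff above; the point is that the latent-dimension bias cannot survive past `fc2_latent_proj` unchanged.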


# Initialize latent projections
if self.config.moe_latent_size:
assert HAVE_TE
Contributor:

Perhaps assert HAVE_TE, "TransformerEngine is required for MoE latent projections."

Contributor:

Done!

config=self.config,
init_method=self.config.output_layer_init_method,
bias=self.config.add_bias_linear,
skip_bias_add=True,
Contributor:

Perhaps we could set skip_bias_add=False so that any necessary bias addition is handled internally by TELinear.
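For context on what the flag controls: the `skip_bias_add` convention used by Megatron/TransformerEngine linear layers lets the caller defer the bias addition so it can be fused with a later operation. The toy class below illustrates only that convention (it is a sketch, not the `TELinear` API):

```python
import torch
import torch.nn as nn

class LinearSkipBias(nn.Module):
    """Minimal illustration of the skip_bias_add convention: with
    skip_bias_add=True the layer returns (x @ W^T, bias) and leaves the
    addition to the caller (so it can be fused elsewhere); with False it
    returns (x @ W^T + bias, None), i.e. the bias is handled internally."""

    def __init__(self, in_features: int, out_features: int, skip_bias_add: bool):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)
        self.skip_bias_add = skip_bias_add

    def forward(self, x):
        y = torch.nn.functional.linear(x, self.linear.weight)  # no bias yet
        if self.skip_bias_add:
            return y, self.linear.bias  # caller is responsible for the add
        return y + self.linear.bias, None
```

Both modes yield the same final values once the caller performs the deferred add; the reviewer's suggestion is simply to let the layer do that add itself here, since no fusion opportunity follows the latent projection.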

Comment on lines +306 to +311
# Project the hidden_states from hidden dimension down to latent dimension.
if self.config.moe_latent_size:
assert (
not self.shared_expert_overlap
), "Shared expert overlap not supported when MoE latent projections are used."
hidden_states, _ = self.fc1_latent_proj(hidden_states)

I believe this projection needs to happen before we call self.token_dispatcher.dispatch_preprocess (in router_and_preprocess). Otherwise the hidden_shape in the token dispatcher gets set to the original hidden size instead of the latent size, which will result in a shape error when trying to apply fc2_latent_proj.
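The ordering issue described above can be shown with a toy stand-in for the token dispatcher (the class and flow below are illustrative only; the real dispatcher does far more than record a shape):

```python
import torch

class ToyDispatcher:
    """Toy stand-in for the token dispatcher: dispatch_preprocess records the
    input's shape as hidden_shape, and combine reshapes the expert output
    back to that recorded shape."""

    def dispatch_preprocess(self, hidden_states):
        self.hidden_shape = hidden_states.shape
        return hidden_states

    def combine(self, expert_output):
        return expert_output.view(self.hidden_shape)

# Correct order per the review fix: down-project first, then preprocess, so
# hidden_shape holds the latent size and combine() reshapes consistently.
hidden_size, latent_size = 8, 4
fc1_latent_proj = torch.nn.Linear(hidden_size, latent_size, bias=False)
dispatcher = ToyDispatcher()
x = torch.randn(6, hidden_size)
x_latent = dispatcher.dispatch_preprocess(fc1_latent_proj(x))
out = dispatcher.combine(x_latent)  # reshapes to the recorded latent shape
```

In the reversed order, `hidden_shape` would be `[6, 8]` while the expert output is `[6, 4]`, and the reshape (or the subsequent `fc2_latent_proj`) would fail.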

Contributor:

Fixed, thanks.

…he MoE routing happens in hidden dimension and correct tensor shape is captured in token dispatcher
@copy-pr-bot (bot) commented Dec 5, 2025

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@deepakn94 (Contributor, Author)

/ok to test dd645a8

Signed-off-by: Deepak Narayanan <dnarayanan@nvidia.com>
@deepakn94 (Contributor, Author)

/ok to test 80f1d5f

Signed-off-by: Deepak Narayanan <dnarayanan@nvidia.com>
@deepakn94 (Contributor, Author)

/ok to test 4257d6a

@deepakn94 (Contributor, Author)

/ok to test 44b708d

@ericharper (Contributor)

/ok to test ec58811

@deepakn94 (Contributor, Author)

/ok to test ec58811

8 participants