Changes to support latent MoEs #2296
Conversation
d088236 to b0a2d8c Compare
Thanks for the work. Could you also add unit tests and integration tests covering this feature combined with EP/TP?
yanring left a comment
Overall LGTM, left a few comments.
        16 SMs can generally achieve good bandwidth."""

    moe_latent_size: Optional[int] = None
    """Latent projection dimension for MoE. If None, MoE latent projections are not used."""
Could you elaborate on it a bit here?
Yeah, do you have a reference for latent MoEs?
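For anyone landing here without context, a minimal sketch of the idea (all names and shapes below are illustrative assumptions, not this PR's actual API): with latent projections, each token is down-projected from the hidden dimension to `moe_latent_size` before expert computation and projected back afterwards, so dispatch/combine traffic and the expert MLPs operate on smaller vectors.

```python
# Illustrative sketch of MoE latent projections (hypothetical names, not the
# PR's code). Experts run in a smaller latent space: hidden -> latent before
# dispatch, latent -> hidden after combine, shrinking per-token communication
# volume from hidden_size to moe_latent_size.
def apply_latent_moe(token, w_down, experts_fn, w_up):
    # Down-project: latent[i] = sum_j token[j] * w_down[j][i]
    latent = [sum(token[j] * w_down[j][i] for j in range(len(token)))
              for i in range(len(w_down[0]))]
    out = experts_fn(latent)  # expert MLPs operate in the latent dimension
    # Up-project back to the hidden dimension
    return [sum(out[i] * w_up[i][j] for i in range(len(out)))
            for j in range(len(w_up[0]))]
```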
        dispatched_input, probs = self.dispatch(hidden_states, probs)
        output, mlp_bias = self.routed_experts_compute(dispatched_input, probs, residual)

        if self.config.moe_latent_size and mlp_bias is not None:
Please document the change here to explain the bias handling update.
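To make the question concrete, here is a hedged guess at the convention under discussion (hypothetical helper, not the PR's actual code): when latent projections are enabled, the experts' bias lives in the latent dimension, so it has to be folded into the output before the up projection instead of being returned for a later fused add.

```python
# Hypothetical sketch of the bias-handling change (an assumption, not the
# PR's implementation).
def finalize_routed_output(output, mlp_bias, moe_latent_size):
    # With latent projections enabled, the expert bias is in the latent
    # dimension, so fold it in before up-projecting; otherwise keep
    # returning (output, bias) for a downstream fused addition.
    if moe_latent_size and mlp_bias is not None:
        output = [o + b for o, b in zip(output, mlp_bias)]
        mlp_bias = None
    return output, mlp_bias
```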
        # Initialize latent projections
        if self.config.moe_latent_size:
            assert HAVE_TE
Perhaps `assert HAVE_TE, "TransformerEngine is required for MoE latent projections."`
            config=self.config,
            init_method=self.config.output_layer_init_method,
            bias=self.config.add_bias_linear,
            skip_bias_add=True,
Perhaps we could set `skip_bias_add=False` so that any necessary bias addition is handled internally by `TELinear`.
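For reference, a minimal sketch of the `skip_bias_add` convention (a plain-Python stand-in, not actual TELinear code): with `skip_bias_add=True` the layer returns the unbiased output together with the bias so the caller can fuse the addition into a later op; with `skip_bias_add=False` the addition happens inside the layer.

```python
# Stand-in for the skip_bias_add convention (illustrative, not TELinear).
# weight is stored as a list of output-unit columns.
def linear_forward(x, weight, bias, skip_bias_add):
    out = [sum(xi * wi for xi, wi in zip(x, col)) for col in weight]
    if skip_bias_add:
        return out, bias  # caller is responsible for adding the bias
    return [o + b for o, b in zip(out, bias)], None  # bias added internally
```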
        # Project the hidden_states from hidden dimension down to latent dimension.
        if self.config.moe_latent_size:
            assert (
                not self.shared_expert_overlap
            ), "Shared expert overlap not supported when MoE latent projections are used."
            hidden_states, _ = self.fc1_latent_proj(hidden_states)
I believe this projection needs to happen before we call `self.token_dispatcher.dispatch_preprocess` (in `router_and_preprocess`). Otherwise, the `hidden_shape` in the token dispatcher gets set to the original hidden size instead of the latent size, which will result in a shape error when trying to apply `fc2_latent_proj`.
…the MoE routing happens in the hidden dimension and the correct tensor shape is captured in the token dispatcher
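A hypothetical forward skeleton of the ordering being agreed on here (all names are placeholders, not the actual Megatron-LM code): routing runs on the full hidden dimension, the down projection happens before `dispatch_preprocess` so the dispatcher records the latent shape, and the up projection restores the hidden dimension at the end.

```python
# Placeholder skeleton of the agreed ordering (an illustrative assumption,
# not Megatron-LM's real forward pass).
def moe_forward(hidden_states, router, fc1_latent_proj, dispatcher,
                experts, fc2_latent_proj):
    probs, routing_map = router(hidden_states)       # routing in hidden dim
    latent = fc1_latent_proj(hidden_states)          # down-project to latent
    dispatched = dispatcher.dispatch_preprocess(latent, routing_map)
    expert_out = experts(dispatched, probs)          # experts in latent dim
    combined = dispatcher.combine(expert_out)
    return fc2_latent_proj(combined)                 # back to hidden dim
```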
/ok to test dd645a8
Signed-off-by: Deepak Narayanan <dnarayanan@nvidia.com>
/ok to test 80f1d5f
Signed-off-by: Deepak Narayanan <dnarayanan@nvidia.com>
/ok to test 4257d6a
/ok to test 44b708d
/ok to test ec58811
No description provided.