[NemotronH] Do not force router to run in fp32 #34582
vllm-bot merged 4 commits into vllm-project:main
Conversation
Signed-off-by: Roi Koren <roik@nvidia.com>
2fb9690 to 48ab68e
Code Review
This pull request introduces a valuable performance optimization by removing the forced casting of MoE router logits to float32. The changes in nemotron_h.py correctly implement this, and the special case for DeepSeekV3 is properly handled in flashinfer_trtllm_moe.py. I've found one minor issue: a leftover debug print statement that should be removed.
```python
from vllm.utils.flashinfer import flashinfer_trtllm_fp8_per_tensor_scale_moe

# The DeepSeekV3 routing method requires float32 router logits.
print(routing_method_type)
routing_logits = routing_logits.to(torch.float32)

if routing_bias is not None:
    routing_bias = routing_bias.to(hidden_states.dtype)
```
I remember something about it being important that the bias is in FP32. I understand that in this case we first cast the logits to FP32 (since we're using DeepSeek routing), so the bias effectively ends up in FP32, but wouldn't it make more sense to cast the logits to the bias dtype instead of the other way around?
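A minimal, illustrative sketch of the two casting orders discussed here (the tensor names mirror the diff above, but the values are made up). When the bias is FP32, both orders produce the same FP32 result; the question is only which dtype is treated as authoritative.

```python
import torch

# Illustrative only: router logits in the checkpoint dtype, bias in FP32.
logits = torch.randn(2, 8, dtype=torch.bfloat16)
bias = torch.randn(8, dtype=torch.float32)

# Order in the diff: upcast logits to float32 (DeepSeekV3 routing needs
# fp32 logits); the bias is then already fp32 when added.
a = logits.to(torch.float32) + bias

# Alternative raised here: cast logits to the bias dtype. Equivalent when
# the bias is fp32, but it keeps the bias precision authoritative.
b = logits.to(bias.dtype) + bias

print(a.dtype, b.dtype)  # both float32 in this setup
```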
mgoin left a comment
LGTM, nice find! Maybe we should add an assert on the bias, but this roughly matches other trtllm MoE impls. Do you have any perf results? You mentioned this takes 40% of the time, which doesn't make sense to me.
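A minimal sketch of the assert suggested here (the helper name is hypothetical, not vLLM's actual code): fail loudly if the routing bias dtype drifts from the hidden states dtype, instead of silently casting it.

```python
import torch

# Hypothetical helper illustrating the suggested assert; not vLLM code.
def check_routing_bias(routing_bias, hidden_states):
    if routing_bias is not None:
        assert routing_bias.dtype == hidden_states.dtype, (
            f"routing_bias dtype {routing_bias.dtype} != "
            f"hidden_states dtype {hidden_states.dtype}"
        )

hs = torch.randn(2, 8, dtype=torch.bfloat16)
check_routing_bias(None, hs)            # no bias: nothing to check
check_routing_bias(hs.new_ones(8), hs)  # matching dtype: passes
```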

Purpose
The current code forces the MoE router computation to run in FP32, even though checkpoints store it in bfloat16. Under normal workloads this takes up about 40% of the forward pass and does not provide an accuracy boost.
This PR removes this limitation.
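A minimal sketch of the change, under the assumption that the router is a plain linear gate (the function and argument names below are illustrative, not vLLM's API): the router gemm runs in the checkpoint dtype instead of being upcast to FP32 first.

```python
import torch

# Hypothetical sketch of the before/after behavior described above.
def route(hidden_states, gate_weight, force_fp32=False):
    if force_fp32:
        # Old behavior: upcast both operands, paying for an fp32 matmul.
        return hidden_states.float() @ gate_weight.float().t()
    # New behavior: stay in the checkpoint dtype (e.g. bfloat16).
    return hidden_states @ gate_weight.t()

x = torch.randn(4, 8, dtype=torch.bfloat16)
w = torch.randn(16, 8, dtype=torch.bfloat16)
print(route(x, w).dtype)                   # bfloat16 logits
print(route(x, w, force_fp32=True).dtype)  # float32 logits
```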
Test Plan
No additional tests; all existing tests pass and accuracy does not degrade.
Test Result
All tests pass and accuracy did not degrade.
Running GSM8K gave the following results.
PR:
Main: