Conversation
Pull request overview
This PR patches the grouped top-k implementation for MoE (Mixture of Experts) operations in vLLM on HPU, addressing dtype conversion behavior based on whether grouped top-k is enabled.
- Adds conditional dtype conversion logic based on the `use_grouped_topk` flag
- Implements a patched `grouped_topk` function with batch invariance support and `e_score_correction_bias` handling
- Applies the patch to the vLLM library's fused_moe layer module
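The grouped top-k selection the patch targets follows the DeepSeek-style two-stage routing: experts are partitioned into groups, the top-scoring groups are chosen first, and top-k experts are then selected only within those groups. The sketch below is an illustrative, self-contained version of that idea (not the patched vLLM code); the function name and defaults are assumptions for demonstration.

```python
import torch


def grouped_topk_sketch(gating_output, topk, num_expert_group, topk_group,
                        e_score_correction_bias=None):
    # Illustrative sketch of grouped top-k routing, not the actual vLLM patch.
    scores = torch.softmax(gating_output, dim=-1)
    # Optional bias is applied only when choosing experts, as in
    # e_score_correction_bias-style routing.
    scores_for_choice = scores if e_score_correction_bias is None \
        else scores + e_score_correction_bias
    num_tokens, num_experts = scores.shape
    experts_per_group = num_experts // num_expert_group
    # Stage 1: score each group by its best expert and keep the top groups.
    group_scores = scores_for_choice.view(
        num_tokens, num_expert_group, experts_per_group).max(dim=-1).values
    group_idx = torch.topk(group_scores, k=topk_group, dim=-1).indices
    group_mask = torch.zeros_like(group_scores)
    group_mask.scatter_(1, group_idx, 1.0)
    # Stage 2: mask out experts from non-selected groups, then take top-k.
    score_mask = group_mask.unsqueeze(-1).expand(
        num_tokens, num_expert_group, experts_per_group).reshape(num_tokens, -1)
    masked_scores = scores_for_choice.masked_fill(score_mask == 0, float("-inf"))
    topk_weights, topk_ids = torch.topk(masked_scores, k=topk, dim=-1)
    if e_score_correction_bias is not None:
        # Weights come from the unbiased scores even when bias guided the choice.
        topk_weights = scores.gather(1, topk_ids)
    return topk_weights, topk_ids
```

With `num_expert_group=2` and `topk_group=1`, only experts from the single best group can ever be selected, which is the property the grouped path must preserve regardless of dtype handling.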
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| vllm_gaudi/ops/hpu_fused_moe.py | Adds conditional dtype conversion, implements patched_grouped_topk function, and applies the grouped_topk patch to vllm module |
| vllm_gaudi/ops/hpu_fp8.py | Adds conditional dtype conversion for FP8 operations based on use_grouped_topk flag |
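Applying the patch to vLLM's fused_moe layer module amounts to rebinding a module attribute at import time. The snippet below illustrates that mechanism with a stand-in namespace; `fake_module` and both function bodies are hypothetical, used only to show the patching pattern.

```python
import types

# "fake_module" stands in for the real vLLM fused_moe layer module;
# the original function here is a placeholder.
fake_module = types.SimpleNamespace(grouped_topk=lambda scores: "original")


def patched_grouped_topk(scores):
    # A replacement implementation would go here (e.g. one with batch
    # invariance support and e_score_correction_bias handling).
    return "patched"


# Rebinding the attribute makes all later lookups through the module
# resolve to the patched function.
fake_module.grouped_topk = patched_grouped_topk
```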
```python
if not layer.use_grouped_topk:
    topk_ids = topk_ids.to(torch.int64)
    topk_weights = topk_weights.to(x.dtype)
```
The dtype conversions for topk_ids and topk_weights are now duplicated - they appear both before line 67 (lines 63-64) and within this conditional block (lines 68-69). When use_grouped_topk is False, these conversions happen twice unnecessarily. Consider moving the earlier conversions (lines 63-64) into an else block, or removing the duplicate logic.
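The suggested fix is to perform the conversions in exactly one guarded place. A minimal sketch of that shape, with an assumed helper name and simplified arguments:

```python
import torch


def convert_outputs(x, topk_weights, topk_ids, use_grouped_topk):
    # Hypothetical condensed sketch of the review suggestion: run the dtype
    # conversions once, guarded by use_grouped_topk, instead of duplicating
    # them unconditionally earlier and again inside the conditional block.
    if not use_grouped_topk:
        topk_ids = topk_ids.to(torch.int64)
        topk_weights = topk_weights.to(x.dtype)
    return topk_weights, topk_ids
```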
```python
topk_weights = topk_weights.view(*x.shape[:-1], -1)
if not layer.use_grouped_topk:
    topk_ids = topk_ids.to(torch.int64)
    topk_weights = topk_weights.to(x.dtype)
```
The dtype conversions for topk_ids and topk_weights are duplicated - they appear both before line 163 (lines 159-160) and within this conditional block (lines 164-165). When use_grouped_topk is False, these conversions happen twice unnecessarily. Consider moving the earlier conversions (lines 159-160) into an else block, or removing the duplicate logic.
Signed-off-by: Xinyu Chen <xinyu1.chen@intel.com>
✅ CI Passed. All checks passed successfully against the following vllm commit:
No description provided.