Fix: Change logic in allreduce.py to be consistent #3077

askliar wants to merge 6 commits into flashinfer-ai:main

Conversation
Removed scheduled trigger for issue comment workflow.
📝 Walkthrough

Multiple function signatures across the codebase receive explicit return type annotations.

Changes
Estimated code review effort: 🎯 2 (Simple) | ⏱️ ~12 minutes

Possibly related PRs
🚥 Pre-merge checks: ❌ 3 failed checks (3 warnings)
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
flashinfer/comm/allreduce.py (2)
452-484: ⚠️ Potential issue | 🟠 Major

Add `@backend_requirement` on `allreduce_fusion`.

This high-level API has architecture/backend-specific constraints (e.g., the pattern/backend support matrix) but is only decorated with `@flashinfer_api`. As per coding guidelines, "Use `@backend_requirement` decorator on APIs with architecture-specific requirements to track supported compute capabilities".

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `flashinfer/comm/allreduce.py` around lines 452-484: the `allreduce_fusion` API is missing the `@backend_requirement` decorator and should declare its architecture/backend constraints. Add `@backend_requirement(...)` above the `allreduce_fusion` definition (keeping the existing `@flashinfer_api`) with the appropriate supported backends/patterns, and import `backend_requirement` from where the decorators are defined. Ensure the decorator is applied to the `allreduce_fusion` function (not replacing `@flashinfer_api`) so the function signature and behavior remain unchanged while backend requirements are recorded.
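To illustrate the requested change, here is a minimal, hypothetical sketch of a backend-requirement decorator pattern. The real `backend_requirement` decorator in flashinfer lives elsewhere and its signature likely differs; the backend names and the `_supported_backends` attribute below are illustrative assumptions, not flashinfer's actual API.

```python
# Hypothetical sketch only: flashinfer's real `backend_requirement`
# decorator has its own (different) signature and semantics.
from functools import wraps


def backend_requirement(supported_backends):
    """Record a supported-backend set on an API and validate at call time."""
    def decorator(fn):
        @wraps(fn)
        def wrapper(*args, backend="trtllm", **kwargs):
            if backend not in supported_backends:
                raise ValueError(
                    f"{fn.__name__} does not support backend {backend!r}; "
                    f"supported: {sorted(supported_backends)}"
                )
            return fn(*args, **kwargs)
        # Expose the constraint for tooling/introspection (illustrative).
        wrapper._supported_backends = frozenset(supported_backends)
        return wrapper
    return decorator


@backend_requirement({"trtllm", "vllm"})  # assumed backend names
def allreduce_fusion_stub(x):
    # Stands in for the real allreduce_fusion; signature unchanged.
    return x
```

The point of the pattern is that stacking the decorator above the existing one leaves the wrapped function's signature and behavior intact while the constraint is recorded and enforced.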
710-741: ⚠️ Potential issue | 🟠 Major

Fallback return order is currently ineffective due to unconditional tensor allocation.
Lines 710 and 713 always materialize `norm_out`/`residual_out`, so lines 736-741 are effectively unreachable and the function always returns `norm_out`. This also forces extra allocations even when callers only want other outputs.

Suggested fix (preserve fallback behavior while keeping at least one output):
```diff
-    if norm_out is None:
-        norm_out = torch.empty_like(residual_in)
-    if residual_out is None:
-        residual_out = torch.empty_like(residual_in)
+    # Keep outputs optional so return fallback order is meaningful.
+    # Ensure at least one output exists if caller didn't request any.
+    if norm_out is None and quant_out is None and residual_out is None:
+        norm_out = torch.empty_like(residual_in)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `flashinfer/comm/allreduce.py` around lines 710-741: the code currently always allocates `norm_out` and `residual_out` before calling `trtllm_moe_finalize_allreduce_fusion`, making the fallback returns ineffective. Only allocate a default output tensor when absolutely needed: before the call, inspect the requested outputs (`norm_out`, `quant_out`, `residual_out`) and, if all are None, allocate a single `torch.empty_like(residual_in)` and assign it to the highest-priority output you intend to return (e.g., `norm_out`) so the call has at least one output buffer; otherwise do not allocate the unused outputs. Pass the original or the single allocated buffer to `trtllm_moe_finalize_allreduce_fusion` and leave the post-call return logic unchanged.
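The allocate-only-when-needed pattern the comment describes can be sketched in plain Python (torch tensors replaced by lists so the sketch is self-contained; `finalize_outputs` is a made-up stand-in for the real function):

```python
# Sketch of "allocate a default output only when no output was requested",
# then return the first non-None output. Names mirror the review comment;
# the real code operates on torch tensors and calls a fused kernel.
def finalize_outputs(residual_in, norm_out=None, quant_out=None, residual_out=None):
    # If the caller requested no outputs, allocate one default buffer and
    # assign it to the highest-priority output (norm_out), so the kernel
    # has at least one buffer to write into.
    if norm_out is None and quant_out is None and residual_out is None:
        norm_out = [0.0] * len(residual_in)  # stands in for torch.empty_like

    # ... kernel call would write into whichever buffers are non-None ...

    # Fallback return order: first non-None output, else the input.
    for out in (norm_out, quant_out, residual_out):
        if out is not None:
            return out
    return residual_in
```

With this shape, passing only `quant_out` no longer triggers extra allocations for `norm_out`/`residual_out`, and the fallback return order actually takes effect.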
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 11f4f6e2-af0c-4a4a-b8eb-626926babaf1
📒 Files selected for processing (4)
- flashinfer/aot.py
- flashinfer/autotuner.py
- flashinfer/comm/allreduce.py
- flashinfer/jit/core.py
Code Review
This pull request introduces type annotations across several modules and updates the `allreduce_fusion` function in allreduce.py to support a `routed_scaling_factor`. The return logic in `allreduce_fusion` was also modified to handle multiple potential output tensors. A review comment suggests simplifying the new conditional return block in `allreduce_fusion` by iterating through the possible outputs to return the first non-null value.
```python
if norm_out is not None:
    return norm_out
elif quant_out is not None:
    return quant_out
elif residual_out is not None:
    return residual_out
else:
    return input
```
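The simplification the review suggests — iterate over the candidates and return the first non-None one — can be written as a single expression. A minimal sketch (the standalone function name is illustrative; in the PR this would just replace the if/elif chain inline):

```python
# Return the first non-None output, falling back to `input`.
# Equivalent to the if/elif/else chain above.
def first_non_none(input, norm_out=None, quant_out=None, residual_out=None):
    return next(
        (out for out in (norm_out, quant_out, residual_out) if out is not None),
        input,
    )
```

This keeps the same priority order (`norm_out` > `quant_out` > `residual_out` > `input`) while making it trivial to extend if another output tensor is added later.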
Summary by CodeRabbit
New Features

- `routed_scaling_factor` parameter added to the allreduce fusion function for enhanced flexibility.

Refactor