
Fix: Change logic in allreduce.py to be consistent#3077

Open
askliar wants to merge 6 commits into flashinfer-ai:main from askliar:fix/make_precommit_pass

Conversation

@askliar
Contributor

@askliar askliar commented Apr 15, 2026

Summary by CodeRabbit

  • New Features

    • Added optional routed_scaling_factor parameter to allreduce fusion function for enhanced flexibility.
  • Refactor

    • Enhanced type annotations across the codebase for improved code quality and maintainability.

@coderabbitai
Contributor

coderabbitai bot commented Apr 15, 2026

📝 Walkthrough

Walkthrough

Multiple function signatures across the codebase receive explicit return type annotations (-> None), and the allreduce_fusion() API gains an optional routed_scaling_factor parameter to support MOE fusion operations with configurable scaling behavior.

Changes

Cohort / File(s) Summary
Type Annotations
flashinfer/aot.py, flashinfer/jit/core.py
Added -> None return type annotation to main() and JitSpecRegistry.__init__() methods.
AutoTuner Type Hints
flashinfer/autotuner.py
Added explicit type annotations to __init__ parameters (warmup: int, repeat: int, stream_delay_micro_secs: int) and -> None return type. Typed profiling_cache attribute as Dict[Tuple[Any, ...], Tuple[int, Any, OptimizationProfile]].
MOE Allreduce Fusion Enhancement
flashinfer/comm/allreduce.py
Added optional routed_scaling_factor: Optional[float] = None parameter to allreduce_fusion(). Updated MOE finalize fusion path to pass routed_scaling_factor instead of None and changed return logic to conditionally return norm_out, quant_out, residual_out, or input based on availability.
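
Because the new parameter defaults to None, existing callers are unaffected. A minimal, self-contained sketch of how such an optional scaling factor is typically threaded through — not the real flashinfer implementation; `allreduce_fusion_sketch` and its list-based input are illustrative stand-ins:

```python
from typing import List, Optional

def allreduce_fusion_sketch(
    tokens: List[float],
    routed_scaling_factor: Optional[float] = None,
) -> List[float]:
    # None preserves the previous behavior (no extra scaling);
    # a float applies the MOE routed scaling to every element.
    scale = 1.0 if routed_scaling_factor is None else routed_scaling_factor
    return [t * scale for t in tokens]
```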

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~12 minutes

Possibly related PRs

Suggested labels

run-ci, op: comm

Suggested reviewers

  • aleozlx
  • samuellees
  • bkryu
  • yzh119
  • yyihuang
  • jimmyzho

Poem

🐰✨ Types align, annotations shine so bright,
MOE scaling factors take their rightful flight,
Fusion paths refined with clarity profound,
Return flows dancing as the best is found! 🎯

🚥 Pre-merge checks | ❌ 3

❌ Failed checks (3 warnings)

  • Title check ⚠️ Warning — The PR title focuses on 'allreduce.py' logic consistency, but the changeset also includes type annotation updates to three other files (aot.py, autotuner.py, jit/core.py), which represent the majority of changes. Resolution: update the title to reflect all changes, such as 'chore: Add type annotations and fix allreduce logic consistency'.
  • Description check ⚠️ Warning — No description was provided by the author; the PR description is completely empty, failing to meet the repository's required template structure and missing context about the changes. Resolution: add a comprehensive description following the template with sections for Description, Related Issues, and confirmation of Pre-commit Checks and Tests completion.
  • Docstring Coverage ⚠️ Warning — Docstring coverage is 40.00%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.


Comment @coderabbitai help to get the list of available commands and usage tips.

Contributor

@coderabbitai coderabbitai bot left a comment


Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
flashinfer/comm/allreduce.py (2)

452-484: ⚠️ Potential issue | 🟠 Major

Add @backend_requirement on allreduce_fusion.

This high-level API has architecture/backend-specific constraints (e.g., pattern/backend support matrix) but is only decorated with @flashinfer_api.

As per coding guidelines, "Use @backend_requirement decorator on APIs with architecture-specific requirements to track supported compute capabilities".

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/comm/allreduce.py` around lines 452 - 484, The allreduce_fusion
API is missing the `@backend_requirement` decorator and should declare its
architecture/backend constraints; add `@backend_requirement`(...) above the
allreduce_fusion definition (keeping the existing `@flashinfer_api`) with the
appropriate supported backends/patterns, and import backend_requirement where
decorators are defined; ensure the decorator is applied to the allreduce_fusion
function (and not replacing `@flashinfer_api`) so the function signature and
behavior remain unchanged while backend requirements are recorded.

710-741: ⚠️ Potential issue | 🟠 Major

Fallback return order is currently ineffective due to unconditional tensor allocation.

Line 710 and Line 713 always materialize norm_out/residual_out, so Lines 736-741 are effectively unreachable and the function always returns norm_out. This also forces extra allocations even when callers only want other outputs.

Suggested fix (preserve fallback behavior while keeping at least one output)
-            if norm_out is None:
-                norm_out = torch.empty_like(residual_in)
-            if residual_out is None:
-                residual_out = torch.empty_like(residual_in)
+            # Keep outputs optional so return fallback order is meaningful.
+            # Ensure at least one output exists if caller didn't request any.
+            if norm_out is None and quant_out is None and residual_out is None:
+                norm_out = torch.empty_like(residual_in)
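The suggested change can be sketched outside torch as follows. This is a hypothetical helper, not code from the PR; `empty_like` is injected as a parameter (defaulting to a plain-list stand-in) so the sketch runs without torch:

```python
from typing import Callable, Optional, Tuple

def ensure_one_output(
    norm_out: Optional[list],
    quant_out: Optional[list],
    residual_out: Optional[list],
    residual_in: list,
    empty_like: Callable[[list], list] = lambda t: [0.0] * len(t),
) -> Tuple[Optional[list], Optional[list], Optional[list]]:
    # Allocate a default buffer only when the caller requested no outputs,
    # so the fallback return priority (norm -> quant -> residual -> input)
    # stays meaningful and unused buffers are never materialized.
    if norm_out is None and quant_out is None and residual_out is None:
        norm_out = empty_like(residual_in)
    return norm_out, quant_out, residual_out
```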
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/comm/allreduce.py` around lines 710 - 741, The code currently
always allocates norm_out and residual_out before calling
trtllm_moe_finalize_allreduce_fusion, making the fallback returns ineffective;
change this so you only allocate a default output tensor when absolutely needed.
Before calling trtllm_moe_finalize_allreduce_fusion inspect the requested
outputs (norm_out, quant_out, residual_out) and if all are None allocate a
single temp = torch.empty_like(residual_in) and assign it to the
highest-priority output you intend to return (e.g., norm_out) so the call has at
least one output buffer; otherwise do not allocate the unused outputs. Update
the call site of trtllm_moe_finalize_allreduce_fusion to pass the original or
the single allocated buffer and leave the post-call return logic unchanged.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Outside diff comments:
In `@flashinfer/comm/allreduce.py`:
- Around line 452-484: The allreduce_fusion API is missing the
`@backend_requirement` decorator and should declare its architecture/backend
constraints; add `@backend_requirement`(...) above the allreduce_fusion definition
(keeping the existing `@flashinfer_api`) with the appropriate supported
backends/patterns, and import backend_requirement where decorators are defined;
ensure the decorator is applied to the allreduce_fusion function (and not
replacing `@flashinfer_api`) so the function signature and behavior remain
unchanged while backend requirements are recorded.
- Around line 710-741: The code currently always allocates norm_out and
residual_out before calling trtllm_moe_finalize_allreduce_fusion, making the
fallback returns ineffective; change this so you only allocate a default output
tensor when absolutely needed. Before calling
trtllm_moe_finalize_allreduce_fusion inspect the requested outputs (norm_out,
quant_out, residual_out) and if all are None allocate a single temp =
torch.empty_like(residual_in) and assign it to the highest-priority output you
intend to return (e.g., norm_out) so the call has at least one output buffer;
otherwise do not allocate the unused outputs. Update the call site of
trtllm_moe_finalize_allreduce_fusion to pass the original or the single
allocated buffer and leave the post-call return logic unchanged.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 11f4f6e2-af0c-4a4a-b8eb-626926babaf1

📥 Commits

Reviewing files that changed from the base of the PR and between bf9b1da and 88f5a62.

📒 Files selected for processing (4)
  • flashinfer/aot.py
  • flashinfer/autotuner.py
  • flashinfer/comm/allreduce.py
  • flashinfer/jit/core.py

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces type annotations across several modules and updates the allreduce_fusion function in allreduce.py to support a routed_scaling_factor. The return logic in allreduce_fusion was also modified to handle multiple potential output tensors. A review comment suggests simplifying the new conditional return block in allreduce_fusion by iterating through the possible outputs to return the first non-null value.

Comment on lines +734 to +741
if norm_out is not None:
    return norm_out
elif quant_out is not None:
    return quant_out
elif residual_out is not None:
    return residual_out
else:
    return input
Contributor


medium

The if-elif-else chain for returning outputs is redundant and can be simplified using a list or tuple of potential outputs to return the first non-None value.

for out in [norm_out, quant_out, residual_out]:
    if out is not None:
        return out
return input
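
Pulled out as a standalone helper (the function name is hypothetical; the parameter names mirror the snippet above), the suggested simplification behaves like this:

```python
from typing import Optional

def first_available_output(
    norm_out: Optional[object],
    quant_out: Optional[object],
    residual_out: Optional[object],
    input: object,
) -> object:
    # Return the first requested output in priority order,
    # falling back to the input tensor when none was provided.
    for out in (norm_out, quant_out, residual_out):
        if out is not None:
            return out
    return input
```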
