Skip to content

[Kernel] FlashInfer: switch allreduce fusion to unified API#33985

Merged
ProExpertProg merged 3 commits intovllm-project:mainfrom
mmangkad:update-vllm-flashinfer-allreduce
Feb 9, 2026
Merged

[Kernel] FlashInfer: switch allreduce fusion to unified API#33985
ProExpertProg merged 3 commits intovllm-project:mainfrom
mmangkad:update-vllm-flashinfer-allreduce

Conversation

@mmangkad
Copy link
Contributor

@mmangkad mmangkad commented Feb 6, 2026

Purpose

  • Migrate vLLM’s FlashInfer allreduce fusion to the unified flashinfer.comm.allreduce_fusion API and workspace creation.
  • Update the fused collective benchmark and fusion test to the new API and workspace lifecycle.

Dependency note: Requires flashinfer-python >= 0.6.3 (unified allreduce API).

Test Plan

pytest tests/compile/distributed/test_fusion_all_reduce.py

Test Result

========================================================================== test session starts ===========================================================================
platform linux -- Python 3.12.12, pytest-9.0.2, pluggy-1.6.0
rootdir: /vllm
configfile: pyproject.toml
plugins: anyio-4.12.1
collected 8 items

tests/compile/distributed/test_fusion_all_reduce.py ........                                                                                                       [100%]

============================================================================ warnings summary ============================================================================
<frozen importlib._bootstrap>:488
  <frozen importlib._bootstrap>:488: DeprecationWarning: builtin type SwigPyPacked has no __module__ attribute

<frozen importlib._bootstrap>:488
  <frozen importlib._bootstrap>:488: DeprecationWarning: builtin type SwigPyObject has no __module__ attribute

../usr/local/lib/python3.12/site-packages/astor/op_util.py:92
  /usr/local/lib/python3.12/site-packages/astor/op_util.py:92: DeprecationWarning: ast.Num is deprecated and will be removed in Python 3.14; use ast.Constant instead
    precedence_data = dict((getattr(ast, x, None), z) for x, y, z in op_data)

tests/utils.py:1427
  /vllm/tests/utils.py:1427: PytestCollectionWarning: cannot collect test class 'TestFP8Layer' because it has a __init__ constructor (from: tests/compile/distributed/test_fusion_all_reduce.py)
    class TestFP8Layer(torch.nn.Module):

tests/compile/backend.py:40
  /vllm/tests/compile/backend.py:40: PytestCollectionWarning: cannot collect test class 'TestBackend' because it has a __init__ constructor (from: tests/compile/distributed/test_fusion_all_reduce.py)
    class TestBackend:

tests/compile/distributed/test_fusion_all_reduce.py:40
  /vllm/tests/compile/distributed/test_fusion_all_reduce.py:40: PytestCollectionWarning: cannot collect test class 'TestAllReduceRMSNormModel' because it has a __init__ constructor (from: tests/compile/distributed/test_fusion_all_reduce.py)
    class TestAllReduceRMSNormModel(torch.nn.Module):

tests/compile/distributed/test_fusion_all_reduce.py:77
  /vllm/tests/compile/distributed/test_fusion_all_reduce.py:77: PytestCollectionWarning: cannot collect test class 'TestAllReduceRMSNormStaticQuantFP8Model' because it has a __init__ constructor (from: tests/compile/distributed/test_fusion_all_reduce.py)
    class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):

tests/compile/distributed/test_fusion_all_reduce.py:128
  /vllm/tests/compile/distributed/test_fusion_all_reduce.py:128: PytestCollectionWarning: cannot collect test class 'TestAllReduceFusedAddRMSNormStaticQuantFP4Model' because it has a __init__ constructor (from: tests/compile/distributed/test_fusion_all_reduce.py)
    class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module):

tests/compile/distributed/test_fusion_all_reduce.py::test_all_reduce_fusion_pass_replace[True-dtype0-64-8-8-TestAllReduceRMSNormModel-False]
tests/compile/distributed/test_fusion_all_reduce.py::test_all_reduce_fusion_pass_replace[True-dtype0-64-8-8-TestAllReduceRMSNormStaticQuantFP8Model-True]
tests/compile/distributed/test_fusion_all_reduce.py::test_all_reduce_fusion_pass_replace[True-dtype0-64-8-8-TestAllReduceRMSNormStaticQuantFP8Model-False]
tests/compile/distributed/test_fusion_all_reduce.py::test_all_reduce_fusion_pass_replace[True-dtype0-64-8-8-TestAllReduceFusedAddRMSNormStaticQuantFP4Model-False]
tests/compile/distributed/test_fusion_all_reduce.py::test_all_reduce_fusion_pass_replace[False-dtype0-64-8-8-TestAllReduceRMSNormModel-False]
tests/compile/distributed/test_fusion_all_reduce.py::test_all_reduce_fusion_pass_replace[False-dtype0-64-8-8-TestAllReduceRMSNormStaticQuantFP8Model-True]
tests/compile/distributed/test_fusion_all_reduce.py::test_all_reduce_fusion_pass_replace[False-dtype0-64-8-8-TestAllReduceRMSNormStaticQuantFP8Model-False]
tests/compile/distributed/test_fusion_all_reduce.py::test_all_reduce_fusion_pass_replace[False-dtype0-64-8-8-TestAllReduceFusedAddRMSNormStaticQuantFP4Model-False]
  /vllm/tests/utils.py:1002: DeprecationWarning: This process (pid=458) is multi-threaded, use of fork() may lead to deadlocks in the child.
    pid = os.fork()

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
=============================================================== 8 passed, 16 warnings in 427.25s (0:07:07) ===============================================================
sys:1: DeprecationWarning: builtin type swigvarlink has no __module__ attribute

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

Signed-off-by: Mohammad Miadh Angkad <176301910+mmangkad@users.noreply.github.com>
@mergify mergify bot added the performance Performance-related issues label Feb 6, 2026
Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request migrates the FlashInfer allreduce fusion to the new unified API. The changes span across the benchmark, tests, and the core compilation fusion logic. The migration correctly adopts the new allreduce_fusion function and the object-oriented workspace management. Overall, the changes are well-implemented. I've found a critical issue in the benchmark file where a function returns an incorrect number of values, which would lead to a runtime error under certain conditions. A code suggestion has been provided to fix this.

Signed-off-by: Mohammad Miadh Angkad <176301910+mmangkad@users.noreply.github.com>
@mmangkad
Copy link
Contributor Author

mmangkad commented Feb 8, 2026

cc @ProExpertProg could you run CI for this, since the FI version has already been bumped to 0.6.3?

@ProExpertProg ProExpertProg added the ready ONLY add when PR is ready to merge/full CI is needed label Feb 8, 2026
Copy link
Contributor

@ilmarkov ilmarkov left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Could you also run gsm8k lm-eval for some TP deployment with allreduce fusion enabled to validate?

@mmangkad
Copy link
Contributor Author

mmangkad commented Feb 9, 2026

LGTM. Could you also run gsm8k lm-eval for some TP deployment with allreduce fusion enabled to validate?

gpt-oss-120b low-reasoning run with tp=2 on gqpa:

Writing report to /tmp/gpqa_openai__gpt-oss-120b-low_temp1.0_20260209_144255.html
{'chars': np.float64(182.4324494949495), 'chars:std': np.float64(353.92776681661746), 'score': np.float64(0.6609848484848485), 'score:std': np.float64(0.47337498725461863)}
Writing results to /tmp/gpqa_openai__gpt-oss-120b-low_temp1.0_20260209_144255.json
Writing all results to /tmp/gpqa_openai__gpt-oss-120b-low_temp1.0_20260209_144255_allresults.json
[{'eval_name': 'gpqa', 'model_name': 'openai__gpt-oss-120b-low_temp1.0_20260209_144255', 'metric': 0.6609848484848485}]

@ProExpertProg ProExpertProg enabled auto-merge (squash) February 9, 2026 15:02
@ProExpertProg ProExpertProg merged commit d4f123c into vllm-project:main Feb 9, 2026
55 checks passed
@mmangkad mmangkad deleted the update-vllm-flashinfer-allreduce branch February 9, 2026 15:44
llsj14 pushed a commit to llsj14/vllm that referenced this pull request Mar 1, 2026
…ject#33985)

Signed-off-by: Mohammad Miadh Angkad <176301910+mmangkad@users.noreply.github.com>
tunglinwood pushed a commit to tunglinwood/vllm that referenced this pull request Mar 4, 2026
…ject#33985)

Signed-off-by: Mohammad Miadh Angkad <176301910+mmangkad@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

performance Performance-related issues ready ONLY add when PR is ready to merge/full CI is needed

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants