[Kernel] FlashInfer: switch allreduce fusion to unified API#33985
[Kernel] FlashInfer: switch allreduce fusion to unified API#33985ProExpertProg merged 3 commits intovllm-project:mainfrom
Conversation
Signed-off-by: Mohammad Miadh Angkad <176301910+mmangkad@users.noreply.github.com>
There was a problem hiding this comment.
Code Review
This pull request migrates the FlashInfer allreduce fusion to the new unified API. The changes span across the benchmark, tests, and the core compilation fusion logic. The migration correctly adopts the new allreduce_fusion function and the object-oriented workspace management. Overall, the changes are well-implemented. I've found a critical issue in the benchmark file where a function returns an incorrect number of values, which would lead to a runtime error under certain conditions. A code suggestion has been provided to fix this.
Signed-off-by: Mohammad Miadh Angkad <176301910+mmangkad@users.noreply.github.com>
|
cc @ProExpertProg could you run CI for this, since the FI version has already been bumped to 0.6.3? |
ilmarkov
left a comment
There was a problem hiding this comment.
LGTM. Could you also run gsm8k lm-eval for some TP deployment with allreduce fusion enabled to validate?
gpt-oss-120b low-reasoning run with tp=2 on gqpa: |
…ject#33985) Signed-off-by: Mohammad Miadh Angkad <176301910+mmangkad@users.noreply.github.com>
…ject#33985) Signed-off-by: Mohammad Miadh Angkad <176301910+mmangkad@users.noreply.github.com>
Purpose
flashinfer.comm.allreduce_fusionAPI and workspace creation.Test Plan
pytest tests/compile/distributed/test_fusion_all_reduce.pyTest Result
Essential Elements of an Effective PR Description Checklist
supported_models.mdandexamplesfor a new model.