[torch.compile] Enable attention and allreduce fusion without custom ops enabled #24604
Conversation
The reason this is needed is that it lets us do fusion without having to enable custom ops (-O.custom_ops=["+quant_fp8"]). Enabling custom ops costs performance, as demonstrated in the PR description, because there are 4 quant ops per layer, one per matmul. I agree this is a somewhat fragile approach. I would be happy to work on a "lowering" approach where we preserve the high-level structure of ops until later. The downside is that it would require more work (I think), and we might lose access to optimizations that currently happen before our passes. But I think it wouldn't hurt Inductor in general to have a more explicit sense of converting between higher-level and lower-level representations (or we could just move where our custom passes happen). We can tie this work into the "autotuning custom op implementations" effort, as done in pytorch/pytorch#164212.
As discussed offline, we are going to proceed by merging this PR. After PTC, we will move our custom op matching passes to
View/slice no-op eliminations were upstreamed to PyTorch, so I'm wondering if that is sufficient: pytorch/pytorch#151095, pytorch/pytorch#151175
@BoyuanFeng wouldn't that run after |
Purpose
This PR enables matching the torch implementations of the custom ops QuantFP8 and RMSNorm. On main, fusion currently requires enabling the custom ops, but they are slower than their torch counterparts, which reduces the benefit of the custom fusion passes. We add a set of "matcher util" objects that can be called inside patterns and automatically get traced to the same fx nodes as the custom op they correspond to, in both its enabled and disabled form.
This PR also adds debugging utilities and E2E fusion tests that verify fusions happen in full models, not just in unit tests.
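A minimal, self-contained sketch of the matcher-util idea (all names here, e.g. MatcherRMSNorm and the demo::rms_norm op, are illustrative placeholders rather than the helpers actually added by this PR): the util emits a single custom-op node when the custom op is enabled, and the decomposed torch ops of the native implementation when it is disabled, so one fusion pattern written against it can match both forms.

```python
import torch
from torch.library import custom_op


# Toy custom op standing in for a fused kernel (hypothetical, for illustration only).
@custom_op("demo::rms_norm", mutates_args=())
def demo_rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    var = x.float().pow(2).mean(dim=-1, keepdim=True)
    return (x.float() * torch.rsqrt(var + eps)).to(x.dtype) * weight


@demo_rms_norm.register_fake
def _(x, weight, eps):
    return torch.empty_like(x)


class MatcherRMSNorm:
    """Callable used inside fusion patterns; traces to whichever form is active."""

    def __init__(self, eps: float, custom_op_enabled: bool):
        self.eps = eps
        self.custom_op_enabled = custom_op_enabled

    def __call__(self, x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
        if self.custom_op_enabled:
            # Produces a single custom-op node in the traced graph.
            return torch.ops.demo.rms_norm(x, weight, self.eps)
        # Produces the same decomposed ops as the native torch implementation.
        var = x.float().pow(2).mean(dim=-1, keepdim=True)
        return (x.float() * torch.rsqrt(var + self.eps)).to(x.dtype) * weight
```

A fusion pattern (e.g. RMSNorm followed by FP8 quant) can then be written once in terms of such a util and registered with the pattern matcher; it matches the graph regardless of whether the custom op is turned on.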
Test Plan
Unit tests, plus newly added fusion E2E tests.
Test Result
All tests pass.
Performance numbers
Below are B200 numbers (with flashinfer) from vllm bench serve. We test the following regimes, each adding the corresponding arguments to the base serve command:
- none: -O.custom_ops='["none"]' -O.pass_config={"enable_fi_allreduce_fusion":false,"enable_attn_fusion":false,"enable_noop":true}
- none_fusion_attention: -O.custom_ops='["none"]' -O.pass_config={"enable_fi_allreduce_fusion":false,"enable_attn_fusion":true,"enable_noop":true}
- none_fusion_attention_allreduce: -O.custom_ops='["none"]' -O.pass_config={"enable_fi_allreduce_fusion":true,"enable_attn_fusion":true,"enable_noop":true}
- rms_quant: -O.custom_ops='["none", "+quant_fp8", "+rms_norm"]' -O.pass_config={"enable_fi_allreduce_fusion":false,"enable_attn_fusion":false,"enable_noop":true}
- rms_quant_fusion_attention: -O.custom_ops='["none", "+quant_fp8", "+rms_norm"]' -O.pass_config={"enable_fi_allreduce_fusion":false,"enable_attn_fusion":true,"enable_noop":true}
- rms_quant_fusion_attention_allreduce: -O.custom_ops='["none", "+quant_fp8", "+rms_norm"]' -O.pass_config={"enable_fi_allreduce_fusion":true,"enable_attn_fusion":true,"enable_noop":true}
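For reference, the none_fusion_attention_allreduce regime expressed through the offline LLM API might look roughly like the sketch below. This is a hedged illustration based on the flags above, not the exact invocation used for the benchmarks, and the compilation-config schema may differ between vLLM versions.

```python
# Hypothetical sketch: the none_fusion_attention_allreduce regime configured via
# the offline API instead of CLI -O flags. Keys mirror the flags listed above.
from vllm import LLM

llm = LLM(
    model="redhatai/meta-llama-3.1-70B-Instruct-FP8",
    tensor_parallel_size=4,
    compilation_config={
        "custom_ops": ["none"],
        "pass_config": {
            "enable_fi_allreduce_fusion": True,
            "enable_attn_fusion": True,
            "enable_noop": True,
        },
    },
)
```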
Regimes 2 (none_fusion_attention) and 3 (none_fusion_attention_allreduce) are newly possible with this PR. On main, results are similar, except those two are worse because fusion cannot happen without custom ops enabled.
redhatai/meta-llama-3.1-70B-Instruct-FP8 (TP=1): Past QPS=10 the server is overloaded, so latency spikes and becomes much more variable. Also note that allreduce fusion is a no-op for TP=1.
[Charts: TTFT / TPOT / ITL median (ms) vs. QPS]
redhatai/meta-llama-3.1-70B-Instruct-FP8 (TP=4): Note that allreduce fusion reduces TPOT at low QPS but increases it at high QPS, and it increases TTFT across the board; this will be addressed in #24248 and #24252.
[Charts: TTFT / TPOT / ITL median (ms) vs. QPS]