[DO NOT MERGE][vLLM IR] 2/N batch-invariant-aware dispatching and rms_norm #36816
ProExpertProg wants to merge 9 commits into main from
Conversation
Code Review
This pull request introduces the concept of batch-invariance for IR ops, which is a good step towards more efficient dispatching. A new has_reduction flag is added to IrOp to control the default batch_invariant status of its implementations. However, I've found a critical issue in how the batch_invariant flag is set for native PyTorch implementations. It's hardcoded to True, which is incorrect for reduction operations and inconsistent with how other implementations are handled. This could lead to incorrect behavior and numerical results when the dispatcher is implemented. My review includes a suggestion to fix this inconsistency.
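A minimal sketch of the fix the review suggests, under the assumption that the op's `has_reduction` flag should drive the default `batch_invariant` status of an implementation (the names `IrOp`, `register`, and the tuple-based registry here are simplified stand-ins, not vLLM's actual API):

```python
from dataclasses import dataclass, field

@dataclass
class IrOp:
    # Hypothetical, simplified IR-op registry for illustration only.
    name: str
    has_reduction: bool = False
    impls: list = field(default_factory=list)

    def register(self, fn, batch_invariant=None):
        # Instead of hardcoding batch_invariant=True for native PyTorch
        # implementations, default it from has_reduction: reduction ops are
        # batch-variant unless an implementation explicitly claims otherwise.
        if batch_invariant is None:
            batch_invariant = not self.has_reduction
        self.impls.append((fn, batch_invariant))

rms_norm = IrOp("rms_norm", has_reduction=True)
rms_norm.register(lambda x: x)                        # native: defaults to False
rms_norm.register(lambda x: x, batch_invariant=True)  # batch-invariant kernel stand-in
```

With this default, a native implementation of a reduction op is consistently treated as batch-variant, matching how other implementations are handled.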
Documentation preview: https://vllm--36816.org.readthedocs.build/en/36816/
This pull request has merge conflicts that must be resolved before it can be merged.
Hi @ProExpertProg, the pre-commit checks have failed. Please run:

uv pip install pre-commit>=4.5.1
pre-commit install
pre-commit run --all-files

Then commit the changes and push to your branch.
Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
# Conflicts:
#   .buildkite/test_areas/misc.yaml
#   tests/v1/determinism/test_batch_invariance.py
Signed-off-by: Luka Govedič <lgovedic@redhat.com>
yewentao256 left a comment
Thanks for the work! I am worried that we are making this very complicated with batch invariance.
    return ir.ops.rms_norm(
        x, self.weight.data, self.variance_epsilon, self.variance_size_override
    )
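For context on why this op is sensitive to dispatching: `rms_norm` scales each row by its root-mean-square over the hidden dimension, and that reduction is where batch-shape-dependent tiling can change numerical results. A minimal NumPy reference of the math (not vLLM's kernel; `variance_size_override` is ignored in this sketch):

```python
import numpy as np

def rms_norm_ref(x, weight, eps=1e-6):
    # Reference RMSNorm: reduce (mean of squares) over the last dimension,
    # then rescale. The accumulation order of this reduction is what a
    # batch-invariant kernel must keep fixed across batch shapes.
    variance = np.mean(x.astype(np.float64) ** 2, axis=-1, keepdims=True)
    return (x / np.sqrt(variance + eps)) * weight
```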
Correctness now depends on ir_op_priority.set_priority() already being installed in the current forward context. Outside the normal engine path, dispatch will silently fall back to the native implementation when no priority is set, so VLLM_BATCH_INVARIANT=1 no longer guarantees the batch-invariant RMSNorm path here.
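The failure mode described above can be sketched as follows (a hypothetical dispatcher, not the PR's actual code; `impls` is assumed to be a list of `(fn, batch_invariant)` pairs with the native implementation first):

```python
import os

def select_impl(impls, priority_installed):
    # Hypothetical sketch of the dispatch behavior under discussion.
    if not priority_installed:
        # No priority installed in the forward context: silently fall back
        # to the native implementation, even when VLLM_BATCH_INVARIANT=1.
        # This is the silent-fallback concern raised in the comment.
        return impls[0]
    if os.environ.get("VLLM_BATCH_INVARIANT") == "1":
        for fn, batch_invariant in impls:
            if batch_invariant:
                return (fn, batch_invariant)
    return impls[0]
```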
Could you also benchmark this? I am also interested in the perf difference between the normal path and the IR path for batch invariance.
Additional benchmarks comparing just
Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Purpose

This PR adds batch-invariant-aware kernel dispatching infrastructure to vLLM IR and plugs in the batch-invariant Triton kernel as an implementation of the rms_norm op.

Key changes:
- Extend IrOp to support a batch_invariant flag on implementations, allowing kernel selection based on VLLM_BATCH_INVARIANT mode
- Register the batch-invariant Triton kernel (vllm.kernels.triton.layernorm_batch_invariant) as an implementation of rms_norm

How it works:
- Ops with a reduction (has_reduction, e.g. rms_norm) are not batch-invariant by default
- Implementations can declare batch-invariance via the batch_invariant=True parameter
- Batch-invariant implementations are preferred when VLLM_BATCH_INVARIANT=1

Test Plan
Test Result

lm_eval (B200): main vs. PR

Latency (B200): all were run with vllm bench latency --attention-backend=TRITON_ATTN (TP=1).

H100: TBD
Essential Elements of an Effective PR Description Checklist

- Update supported_models.md and examples for a new model.