
[DO NOT MERGE][vLLM IR] 2/N batch-invariant-aware dispatching and rms_norm #36816

Open
ProExpertProg wants to merge 9 commits into main from luka/vllm-ir/rms-norm-batch-invariant

Conversation

@ProExpertProg
Collaborator

@ProExpertProg ProExpertProg commented Mar 11, 2026

Purpose

This PR adds batch-invariant-aware kernel dispatching infrastructure to vLLM IR and plugs the batch-invariant Triton kernel as an implementation of the rms_norm op.
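As background (an illustrative aside, not code from this PR): floating-point reductions are order-sensitive, which is why a reduction op like rms_norm is not batch-invariant unless its kernel pins the accumulation order:

```python
# Floating-point addition is not associative, so a reduction kernel whose
# accumulation order depends on batch/tile size can produce different results
# for the same input row at different batch sizes.
left_to_right = (0.1 + 0.2) + 0.3   # 0.6000000000000001
regrouped = 0.1 + (0.2 + 0.3)       # 0.6
assert left_to_right != regrouped
```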

Key changes:

  • Extended IrOp to support batch_invariant flag on implementations, allowing kernel selection based on VLLM_BATCH_INVARIANT mode
  • Registered batch-invariant Triton kernel for rms_norm (vllm.kernels.triton.layernorm_batch_invariant)
  • Updated lowering pass to filter implementations by batch-invariance requirements
  • Added platform capability detection for batch-invariant kernel support (SM 9.0+)
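For illustration, the SM 9.0+ gate could be expressed as a small predicate like the following sketch (names are hypothetical; a real worker would feed in `torch.cuda.get_device_capability()`):

```python
# Batch-invariant Triton kernels are only usable on devices with compute
# capability >= SM 9.0 (Hopper and newer); older GPUs keep other impls.
MIN_BATCH_INVARIANT_CAPABILITY = (9, 0)

def supports_batch_invariant_kernels(capability: tuple[int, int]) -> bool:
    # Tuple comparison: (9, 0) >= (9, 0) is True, (8, 9) >= (9, 0) is False.
    return capability >= MIN_BATCH_INVARIANT_CAPABILITY
```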

How it works:

  • IR ops are batch-invariant by default; ops with reductions (has_reduction, e.g. rms_norm) are not batch-invariant by default
  • Native implementations are always batch-invariant
  • Implementations explicitly opt-in via batch_invariant=True parameter
  • Kernel selection automatically chooses batch-invariant implementations when VLLM_BATCH_INVARIANT=1
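The selection rules above can be sketched roughly as follows (class and method names are illustrative, not the PR's actual API):

```python
from __future__ import annotations

from dataclasses import dataclass, field

@dataclass
class Impl:
    name: str
    batch_invariant: bool

@dataclass
class IrOp:
    name: str
    has_reduction: bool = False
    impls: list[Impl] = field(default_factory=list)

    def register(self, name: str, *, native: bool = False,
                 batch_invariant: bool | None = None) -> None:
        if batch_invariant is None:
            # Native impls count as batch-invariant; custom kernels for
            # reduction ops must opt in explicitly.
            batch_invariant = native or not self.has_reduction
        self.impls.append(Impl(name, batch_invariant))

    def candidates(self, require_batch_invariant: bool) -> list[str]:
        # With VLLM_BATCH_INVARIANT=1, only batch-invariant impls survive.
        return [i.name for i in self.impls
                if not require_batch_invariant or i.batch_invariant]

rms_norm = IrOp("rms_norm", has_reduction=True)
rms_norm.register("native", native=True)
rms_norm.register("triton")  # fast Triton kernel, not batch-invariant
rms_norm.register("triton_batch_invariant", batch_invariant=True)
```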

Test Plan

# Unit tests
pytest tests/ir/test_op.py -v -k batch_invariant                                                                                                                                                                             
pytest tests/kernels/ir/test_layernorm.py -v
pytest tests/compile/passes/ir/test_lowering.py -v

# Batch invariance tests
pytest -s -v tests/v1/determinism/test_batch_invariance.py

# E2E: lm_eval, vllm bench latency

Test Result

# H100
$ pytest tests/v1/determinism/test_batch_invariance.py 
tests/v1/determinism/test_batch_invariance.py ............. [100%]
============ 13 passed, 34 warnings in 488.50s (0:08:08) ============
$ pytest tests/ir/ tests/kernels/ir/ tests/compile/passes/ir/ 
tests/ir/test_op.py ....................... [  5%]
tests/kernels/ir/test_layernorm.py .............................................................................................................................................................................................. [ 46%]
.................................................................................ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss.................................... [ 95%]
.................. [ 99%]
tests/compile/passes/ir/test_lowering.py ... [100%]
============ 351 passed, 108 skipped, 19 warnings in 143.29s (0:02:23) ============

# B200
$ pytest tests/v1/determinism/test_batch_invariance.py 
tests/v1/determinism/test_batch_invariance.py ............. [100%]
============ 13 passed, 34 warnings in 559.89s (0:09:19) ============

$ pytest tests/ir/ tests/kernels/ir/ tests/compile/passes/ir/ 
tests/ir/test_op.py ....................... [  5%]
tests/kernels/ir/test_layernorm.py ...................................................................................................................................................................... [ 41%]
.........................................................................................................ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 84%]
ssssssssssss...................................................... [ 99%] 
tests/compile/passes/ir/test_lowering.py ... [100%]
============ 351 passed, 108 skipped, 19 warnings in 161.71s (0:02:41) ============

lm_eval

B200

All runs used `vllm serve Qwen/Qwen3-30B-A3B --attention-backend=TRITON_ATTN`, evaluated via lm_eval local-completions (gsm8k, n-shot 5, num_concurrent 50, max_retries 3, batch_size auto).

| Branch | VLLM_BATCH_INVARIANT | --enforce-eager | flexible-extract exact_match | strict-match exact_match |
|---|---|---|---|---|
| main | 0 | no | 0.8469 ± 0.0099 | 0.8886 ± 0.0087 |
| main | 1 | no | 0.8582 ± 0.0096 | 0.8976 ± 0.0083 |
| main | 0 | yes | 0.8537 ± 0.0097 | 0.8969 ± 0.0084 |
| main | 1 | yes | 0.8506 ± 0.0098 | 0.8969 ± 0.0084 |
| PR | 0 | no | 0.8476 ± 0.0099 | 0.8916 ± 0.0086 |
| PR | 1 | no | 0.8582 ± 0.0096 | 0.8976 ± 0.0083 |
| PR | 0 | yes | 0.8529 ± 0.0098 | 0.8984 ± 0.0083 |
| PR | 1 | yes | 0.8506 ± 0.0098 | 0.8969 ± 0.0084 |

Latency

B200

All were run with vllm bench latency --attention-backend=TRITON_ATTN (TP=1).

| Configuration | main [s] | PR [s] |
|---|---|---|
| **Qwen/Qwen3-0.6B** | | |
| VLLM_BATCH_INVARIANT=0 | 0.288 | 0.278 |
| VLLM_BATCH_INVARIANT=0 --enforce-eager | 2.157 | 2.150 |
| VLLM_BATCH_INVARIANT=1 | 0.493 | 0.490 |
| VLLM_BATCH_INVARIANT=1 --enforce-eager | 3.057 | 4.602 |
| **nvidia/Llama-3.3-70B-Instruct-NVFP4** | | |
| VLLM_BATCH_INVARIANT=0 | 1.612 | 1.620 |
| VLLM_BATCH_INVARIANT=1 | 11.35 | 11.34 |

H100

TBD


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces the concept of batch-invariance for IR ops, which is a good step towards more efficient dispatching. A new has_reduction flag is added to IrOp to control the default batch_invariant status of its implementations. However, I've found a critical issue in how the batch_invariant flag is set for native PyTorch implementations. It's hardcoded to True, which is incorrect for reduction operations and inconsistent with how other implementations are handled. This could lead to incorrect behavior and numerical results when the dispatcher is implemented. My review includes a suggestion to fix this inconsistency.

Comment thread vllm/ir/op.py
@ProExpertProg ProExpertProg added the vllm-ir vLLM IR: intermediate representation and kernel registration label Mar 11, 2026
@ProExpertProg ProExpertProg force-pushed the luka/vllm-ir/rms-norm-batch-invariant branch from 8041106 to d8fe95a Compare March 12, 2026 10:02
Comment thread vllm/ir/op.py Outdated
@ProExpertProg ProExpertProg force-pushed the luka/vllm-ir/rms-norm-batch-invariant branch from d8fe95a to 810b9f3 Compare March 12, 2026 19:42
@ProExpertProg ProExpertProg force-pushed the luka/vllm-ir/rms-norm-batch-invariant branch from b39721e to faa028b Compare March 20, 2026 12:48
@mergify
Contributor

mergify bot commented Mar 20, 2026

Documentation preview: https://vllm--36816.org.readthedocs.build/en/36816/

@mergify mergify bot added documentation Improvements or additions to documentation ci/build deepseek Related to DeepSeek models frontend llama Related to Llama models multi-modality Related to multi-modality (#4194) new-model Requests to new models performance Performance-related issues qwen Related to Qwen models gpt-oss Related to GPT-OSS models labels Mar 20, 2026
@mergify mergify bot added the nvidia label Mar 20, 2026
@mergify mergify bot added the rocm Related to AMD ROCm label Mar 20, 2026
@mergify mergify bot added the cpu Related to CPU backends label Mar 20, 2026
@github-project-automation github-project-automation bot moved this to Todo in AMD Mar 20, 2026
@mergify
Contributor

mergify bot commented Mar 31, 2026

Documentation preview: https://vllm--36816.org.readthedocs.build/en/36816/

@mergify
Contributor

mergify bot commented Mar 31, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @ProExpertProg.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify
Contributor

mergify bot commented Mar 31, 2026

Hi @ProExpertProg, the pre-commit checks have failed. Please run:

uv pip install "pre-commit>=4.5.1"
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy failing?
mypy is run differently in CI. If the failure is related to this check, please use the following command to run it locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10

@mergify
Contributor

mergify bot commented Apr 1, 2026

Documentation preview: https://vllm--36816.org.readthedocs.build/en/36816/

@mergify
Contributor

mergify bot commented Apr 1, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @ProExpertProg.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

ProExpertProg and others added 5 commits April 8, 2026 20:55
Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Signed-off-by: Luka Govedič <lgovedic@redhat.com>
@mergify
Contributor

mergify bot commented Apr 15, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @ProExpertProg.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

Signed-off-by: Luka Govedič <lgovedic@redhat.com>
# Conflicts:
#	.buildkite/test_areas/misc.yaml
#	tests/v1/determinism/test_batch_invariance.py

Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Member

@yewentao256 yewentao256 left a comment


Thanks for the work! I am worried that we are making this very complicated with batch invariance.

return ir.ops.rms_norm(
    x, self.weight.data, self.variance_epsilon, self.variance_size_override
)

Member


Correctness now depends on ir_op_priority.set_priority() already being installed in the current forward context. Outside the normal engine path, dispatch will silently fall back to native when no priority is set, so VLLM_BATCH_INVARIANT=1 no longer guarantees the batch-invariant RMSNorm path here.
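One possible mitigation for this failure mode, sketched with hypothetical names (not the PR's actual code): warn instead of silently falling back when batch-invariant mode is on but no priority context is installed.

```python
from __future__ import annotations

import os
import warnings

def resolve_rms_norm_impl(installed_priority: list[str] | None) -> str:
    # Outside the engine path no priority is installed, so dispatch would
    # silently return the native impl even when VLLM_BATCH_INVARIANT=1.
    # Emitting a warning makes that fallback visible instead of silent.
    if not installed_priority:
        if os.environ.get("VLLM_BATCH_INVARIANT") == "1":
            warnings.warn(
                "no IR op priority installed in the forward context; "
                "falling back to native rms_norm, batch invariance is "
                "not guaranteed")
        return "native"
    return installed_priority[0]
```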

@yewentao256
Member

Could you also benchmark with vllm bench serve... to see token throughput, TTFT, etc.?

I am also interested in the perf difference between the normal path and the IR path for batch invariance.

@ProExpertProg
Collaborator Author

Additional benchmarks comparing just triton_batch_invariant RMSNorm to native (this PR):

Qwen3-0.6B, B200

Seems like the Triton kernel is slightly faster than native.

$ VLLM_BATCH_INVARIANT=1 vllm bench latency --model=Qwen/Qwen3-0.6B --attention-backend=TRITON_ATTN --ir-op-priority.rms_norm=native

Avg latency: 0.46398694468662144 seconds
10% percentile latency: 0.46021976000629367 seconds
25% percentile latency: 0.4604079438140616 seconds
50% percentile latency: 0.46575065865181386 seconds
75% percentile latency: 0.4660729080205783 seconds
90% percentile latency: 0.46662389640696345 seconds
99% percentile latency: 0.46731352266855536 seconds

$ VLLM_BATCH_INVARIANT=1 vllm bench latency --model=Qwen/Qwen3-0.6B --attention-backend=TRITON_ATTN --ir-op-priority.rms_norm=triton_batch_invariant

Avg latency: 0.45794374315689007 seconds
10% percentile latency: 0.4542138213291764 seconds
25% percentile latency: 0.45449806610122323 seconds
50% percentile latency: 0.4595492200460285 seconds
75% percentile latency: 0.4604009126778692 seconds
90% percentile latency: 0.46168355625122787 seconds
99% percentile latency: 0.46257003766018895 seconds

RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8, B200

VLLM_BATCH_INVARIANT=1 vllm bench latency --model=RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8 --attention-backend=TRITON_ATTN --ir-op-priority.rms_norm=native -cc.pass_config.fuse_norm_quant=False

Avg latency: 0.39199300021864475 seconds
10% percentile latency: 0.3913078000303358 seconds
25% percentile latency: 0.3920521972468123 seconds
50% percentile latency: 0.39233378903009 seconds
75% percentile latency: 0.3927705123787746 seconds
90% percentile latency: 0.39286121167242527 seconds
99% percentile latency: 0.39316565855406227 seconds

VLLM_BATCH_INVARIANT=1 vllm bench latency --model=RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8 --attention-backend=TRITON_ATTN --ir-op-priority.rms_norm=triton_batch_invariant -cc.pass_config.fuse_norm_quant=False

Avg latency: 0.39097082077836 seconds
10% percentile latency: 0.386426914203912 seconds
25% percentile latency: 0.390549574396573 seconds
50% percentile latency: 0.3917238435242325 seconds
75% percentile latency: 0.3921746846754104 seconds
90% percentile latency: 0.39293897384777665 seconds
99% percentile latency: 0.39752276936545966 seconds




Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Signed-off-by: Luka Govedič <lgovedic@redhat.com>
