
[vLLM IR] 1/N Implement IR skeleton and rms_norm op #33825

Merged
ProExpertProg merged 56 commits into vllm-project:main from neuralmagic:luka/vllm-ir/rms-norm on Apr 1, 2026

Conversation

ProExpertProg (Collaborator) commented Feb 4, 2026

Purpose

This PR implements the foundational infrastructure for vLLM IR (Intermediate Representation), a functional IR system for vLLM custom operations, starting with the rms_norm operation. It is the first of many PRs addressing RFC #32358.

What is vLLM IR? vLLM IR is a functional intermediate representation that separates operation semantics from implementation and dispatching. It serves as a higher-level torch dialect with the following main benefits:

  • Streamlined compilation passes
  • Single source of truth dispatching
  • Simple and extensible op and implementation/kernel registration

This PR contains the following initial features of vLLM IR:

  • Op and kernel registration: Registration decorator (vllm.ir.register_op) returns an IrOp object which is a callable object containing op metadata and utilities. Impl registration decorator (IrOp.register_impl) returns an IrOpImpl object which contains implementation metadata and utilities.
  • Eager mode dispatching: Upon direct call, IrOp.dispatch dispatches the call to the selected implementation, according to the priority list and runtime support_args predicates.
  • rms_norm op & implementation registration: the op lives in vllm/ir/ops/layernorm.py. The implementations live in vllm/kernels/*.py - different files for different providers.
  • Lowering pass: vllm.compilation.passes.ir.lowering_pass.VllmIRLoweringPass runs at the end of post-grad custom post-passes and lowers vllm_ir torch ops into the selected implementation, according to the priority list and runtime support_args predicates (consistent with eager-mode dispatching).
  • MatcherRMSNorm replaced with torch.ops.vllm_ir.rms_norm: In custom compile passes, the fragile matching utility can be fully replaced by calling the vLLM IR op in pattern matcher patterns and replacements.
  • vLLM-level IR op priority: Adds an IROpPriorityConfig to KernelConfig, including a top-level CLI flag. Each platform also defines its own default op priority, which is combined with any user-specified values. This priority is then passed down to IR op dispatch at the start of every forward pass.
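The registration and dispatch flow above can be sketched with a toy mock. Only the names `vllm.ir.register_op`, `IrOp.register_impl`, and `IrOp.dispatch` come from this PR; everything else here (the `supports_args` signature, fields, and return values) is an illustrative assumption, not the actual vLLM API:

```python
# Toy sketch of op/impl registration with priority-ordered eager dispatch.
# This mirrors the mechanism described above; it is NOT the vLLM code.

class IrOpImpl:
    def __init__(self, name, fn, supports_args):
        self.name = name                    # provider name, e.g. "vllm_c"
        self.fn = fn
        self.supports_args = supports_args  # runtime predicate

class IrOp:
    def __init__(self, name):
        self.name = name
        self.impls = {}    # provider name -> IrOpImpl
        self.priority = [] # ordered provider names

    def register_impl(self, name, supports_args=lambda *a, **k: True):
        def deco(fn):
            self.impls[name] = IrOpImpl(name, fn, supports_args)
            return fn
        return deco

    def dispatch(self, *args, **kwargs):
        # Walk the priority list; pick the first impl whose runtime
        # predicate accepts these arguments (eager-mode dispatching).
        for provider in self.priority:
            impl = self.impls.get(provider)
            if impl is not None and impl.supports_args(*args, **kwargs):
                return impl.fn(*args, **kwargs)
        raise RuntimeError(f"no implementation for {self.name}")

    __call__ = dispatch

def register_op(name):
    return IrOp(name)

# Usage: register a toy "rms_norm" with two hypothetical providers.
rms_norm = register_op("rms_norm")

@rms_norm.register_impl("native")
def _native(x, eps):
    return ("native", x, eps)

@rms_norm.register_impl("vllm_c", supports_args=lambda x, eps: eps > 0)
def _cuda(x, eps):
    return ("vllm_c", x, eps)

rms_norm.priority = ["vllm_c", "native"]
print(rms_norm(1.0, eps=1e-6)[0])  # vllm_c wins while its predicate holds
print(rms_norm(1.0, eps=-1.0)[0])  # falls back to native
```

The lowering pass follows the same selection rule, so compiled and eager execution pick the same implementation.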

Kernel implementation providers:

  • vllm_c: C++/CUDA/HIP kernels (CUDA/ROCm platforms)
  • aiter: ROCm-optimized AITER kernels (ROCm only)
  • xpu_kernels: Intel XPU kernels (XPU platform)
  • oink: Custom oink op kernels
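Given these providers, combining a platform's default priority with user-specified values (e.g. `--ir-op-priority.rms_norm=vllm_c`) might look roughly like the following; `resolve_priority` and all names here are hypothetical, not the vLLM implementation:

```python
# Hypothetical sketch of merging platform-default op priority with
# user overrides: user-specified providers come first, remaining
# platform defaults keep their relative order behind them.

def resolve_priority(platform_default, user_priority):
    merged = list(user_priority)
    merged += [p for p in platform_default if p not in merged]
    return merged

# e.g. a CUDA platform might default to its C++/CUDA kernels first:
cuda_default = ["vllm_c", "oink"]
print(resolve_priority(cuda_default, ["oink"]))  # user override wins
print(resolve_priority(cuda_default, []))        # platform default order
```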

Other non-IR changes:

  • Added disable_log_dedup fixture for testing warning_once and info_once
  • Added print_graphs() helper to test backend
  • Disabled FlashInfer autotuning in fusion E2E tests for faster runtime

fused_add_rms_norm & batch invariant: This PR leaves the fused_add_rms_norm and batch invariant parts of the RMSNorm custom op intact. PRs 2/N (#36816) and 3/N (#36823) will address these two, after which we can migrate RMSNorm from CustomOp to PluggableLayer.

Remaining TODOs:

  • Finalize direct_dispatch implementation
  • Add IR op impl source to compile UUID
  • Figure out the autograd bug & resolution
  • E2E eval, bench, inspect output code

Test Plan

  • Extensive unit tests in new files:
    • tests/ir/test_op.py - IR op registration, dispatching, priority system
    • tests/kernels/ir/test_layernorm.py - RMS norm implementation tests
    • tests/compile/passes/ir/test_lowering.py - Lowering pass tests
  • Existing compile & custom pass unit tests
  • E2E lm_eval, E2E perf, manually check Inductor graph matches, qwen0.6B --enforce-eager to measure worst-case dispatching overhead

Test Result

CI tests passing. Manually validated that Inductor output code is identical for DeepSeek-V3.1: a complex case because it applies rms_norm to q as well as per-layer.

Qwen-0.6B latency configuration sweep on B200:

Command: vllm bench latency --model=Qwen/Qwen3-0.6B

| Configuration | Median latency (main) | Median latency (PR) |
|---|---|---|
| --enforce-eager | 1.037s | 1.010s |
| -cc.cudagraph_mode=FULL_DECODE_ONLY -cc.mode=NONE | 0.229s | 0.229s |
| -cc.cudagraph_mode=FULL_DECODE_ONLY | 0.210s | 0.210s |
| -cc.custom_ops+=+rms_norm --ir-op-priority.rms_norm=vllm_c | 0.214s | 0.216s |
| (default) | 0.197s | 0.196s |

lm_eval (B200)

main

$ vllm serve Qwen/Qwen3-30B-A3B
local-completions ({'pretrained': 'Qwen/Qwen3-30B-A3B', 'base_url': 'http://0.0.0.0:8000/v1/completions', 'num_concurrent': 50, 'max_retries': 3}), gen_kwargs: ({}), limit: None, num_fewshot: 5, batch_size: auto
| Tasks | Version | Filter | n-shot | Metric | Value | Stderr |
|---|---|---|---|---|---|---|
| gsm8k | 3 | flexible-extract | 5 | exact_match | 0.8582 | ± 0.0096 |
| | | strict-match | 5 | exact_match | 0.8984 | ± 0.0083 |
$ vllm serve RedHatAI/Meta-Llama-3.1-70B-FP8 -tp=2
local-completions ({'pretrained': 'RedHatAI/Meta-Llama-3.1-70B-FP8', 'base_url': 'http://0.0.0.0:8000/v1/completions', 'num_concurrent': 50, 'max_retries': 3}), gen_kwargs: ({}), limit: None, num_fewshot: 5, batch_size: auto
| Tasks | Version | Filter | n-shot | Metric | Value | Stderr |
|---|---|---|---|---|---|---|
| gsm8k | 3 | flexible-extract | 5 | exact_match | 0.8052 | ± 0.0109 |
| | | strict-match | 5 | exact_match | 0.8052 | ± 0.0109 |
$ vllm serve deepseek-ai/DeepSeek-V3.1 -dp 8 -ep
local-completions ({'pretrained': 'deepseek-ai/DeepSeek-V3.1', 'base_url': 'http://0.0.0.0:8000/v1/completions', 'num_concurrent': 50, 'max_retries': 3}), gen_kwargs: ({}), limit: None, num_fewshot: 5, batch_size: auto
| Tasks | Version | Filter | n-shot | Metric | Value | Stderr |
|---|---|---|---|---|---|---|
| gsm8k | 3 | flexible-extract | 5 | exact_match | 0.9613 | ± 0.0053 |
| | | strict-match | 5 | exact_match | 0.9606 | ± 0.0054 |

PR

$ vllm serve Qwen/Qwen3-30B-A3B
local-completions ({'pretrained': 'Qwen/Qwen3-30B-A3B', 'base_url': 'http://0.0.0.0:8000/v1/completions', 'num_concurrent': 50, 'max_retries': 3}), gen_kwargs: ({}), limit: None, num_fewshot: 5, batch_size: auto
| Tasks | Version | Filter | n-shot | Metric | Value | Stderr |
|---|---|---|---|---|---|---|
| gsm8k | 3 | flexible-extract | 5 | exact_match | 0.8620 | ± 0.0095 |
| | | strict-match | 5 | exact_match | 0.8946 | ± 0.0085 |
$ vllm serve RedHatAI/Meta-Llama-3.1-70B-FP8 -tp=2
local-completions ({'pretrained': 'RedHatAI/Meta-Llama-3.1-70B-FP8', 'base_url': 'http://0.0.0.0:8000/v1/completions', 'num_concurrent': 50, 'max_retries': 3}), gen_kwargs: ({}), limit: None, num_fewshot: 5, batch_size: auto
| Tasks | Version | Filter | n-shot | Metric | Value | Stderr |
|---|---|---|---|---|---|---|
| gsm8k | 3 | flexible-extract | 5 | exact_match | 0.8029 | ± 0.0110 |
| | | strict-match | 5 | exact_match | 0.8044 | ± 0.0109 |
$ vllm serve deepseek-ai/DeepSeek-V3.1 -dp 8 -ep
local-completions ({'pretrained': 'deepseek-ai/DeepSeek-V3.1', 'base_url': 'http://0.0.0.0:8000/v1/completions', 'num_concurrent': 50, 'max_retries': 3}), gen_kwargs: ({}), limit: None, num_fewshot: 5, batch_size: auto
| Tasks | Version | Filter | n-shot | Metric | Value | Stderr |
|---|---|---|---|---|---|---|
| gsm8k | 3 | flexible-extract | 5 | exact_match | 0.9583 | ± 0.0055 |
| | | strict-match | 5 | exact_match | 0.9575 | ± 0.0056 |

lm_eval (H100)

main

$ vllm serve Qwen/Qwen3-30B-A3B-FP8 -dp 2 -ep -tp 2
local-completions ({'pretrained': 'Qwen/Qwen3-30B-A3B-FP8', 'base_url': 'http://0.0.0.0:8000/v1/completions', 'num_concurrent': 50, 'max_retries': 3}), gen_kwargs: ({}), limit: None, num_fewshot: 5, batch_size: auto
| Tasks | Version | Filter | n-shot | Metric | Value | Stderr |
|---|---|---|---|---|---|---|
| gsm8k | 3 | flexible-extract | 5 | exact_match | 0.8234 | ± 0.0105 |
| | | strict-match | 5 | exact_match | 0.8961 | ± 0.0084 |

PR

$ vllm serve Qwen/Qwen3-30B-A3B-FP8 -dp 2 -ep -tp 2
local-completions ({'pretrained': 'Qwen/Qwen3-30B-A3B-FP8', 'base_url': 'http://0.0.0.0:8000/v1/completions', 'num_concurrent': 50, 'max_retries': 3}), gen_kwargs: ({}), limit: None, num_fewshot: 5, batch_size: auto
| Tasks | Version | Filter | n-shot | Metric | Value | Stderr |
|---|---|---|---|---|---|---|
| gsm8k | 3 | flexible-extract | 5 | exact_match | 0.8271 | ± 0.0104 |
| | | strict-match | 5 | exact_match | 0.8931 | ± 0.0085 |

lm_eval (MI355)

@mergify mergify bot added the ci/build label Feb 4, 2026

mergify bot commented Feb 4, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @ProExpertProg.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Feb 4, 2026
gemini-code-assist bot (Contributor) left a comment:

Code Review

This pull request introduces the foundational skeleton for a new Intermediate Representation (IR) in vLLM and undertakes a significant refactoring of the compilation passes and their associated tests. The new IR system provides a clean and extensible way to define and register custom operations. The refactoring effort organizes compilation passes into a more structured passes subdirectory and overhauls the end-to-end fusion tests for improved maintainability. The CI configuration has also been updated accordingly to reflect these structural changes. My review found one area for improvement in the new IR implementation.

Comment thread on vllm/ir/op.py (Outdated)
@mergify mergify bot removed the needs-rebase label Feb 5, 2026
@ProExpertProg ProExpertProg added torch.compile vllm-ir vLLM IR: intermediate representation and kernel registration and removed ci/build labels Feb 5, 2026
@ProExpertProg ProExpertProg moved this from To triage to In progress in torch.compile integration Feb 5, 2026
@ProExpertProg ProExpertProg force-pushed the luka/vllm-ir/rms-norm branch 7 times, most recently from bc6dbee to 289444c Compare February 6, 2026 23:46
@mergify mergify bot added the ci/build label Feb 7, 2026
@ProExpertProg ProExpertProg force-pushed the luka/vllm-ir/rms-norm branch 2 times, most recently from 2e45e02 to e3227db Compare February 7, 2026 13:50
@ProExpertProg ProExpertProg added the ready ONLY add when PR is ready to merge/full CI is needed label Feb 7, 2026
@mergify mergify bot added nvidia rocm Related to AMD ROCm labels Feb 7, 2026
@github-project-automation github-project-automation bot moved this to Todo in AMD Feb 7, 2026
@ProExpertProg ProExpertProg marked this pull request as ready for review February 7, 2026 14:16

mergify bot commented Mar 23, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @ProExpertProg.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Mar 23, 2026
# Conflicts:
#	vllm/config/kernel.py
#	vllm/model_executor/layers/layernorm.py
#	vllm/platforms/interface.py
Comment on lines +245 to +247
return ir.ops.rms_norm(
    x, self.weight, self.variance_epsilon, self.variance_size_override
)
zou3519 (Collaborator) commented Mar 25, 2026:

Claude said (and I think it's reasonable): use self.weight.data here. Previously, when Dynamo was tracing through this, that is what was being used:

self.weight.data if self.has_weight else None,

This fixes the issue with torch.no_grad(). I'm not sure if there's something larger that is wrong, but changing this back to self.weight.data gets us back the previous behavior.

ProExpertProg (Collaborator, Author) replied:
Wow nice find!

zou3519 (Collaborator) left a review:

LGTM. I think Claude figured out the torch.no_grad() issue; we should fix that before merging.

@mergify mergify bot removed the needs-rebase label Mar 26, 2026

mergify bot commented Mar 26, 2026

Hi @ProExpertProg, the pre-commit checks have failed. Please run:

uv pip install "pre-commit>=4.5.1"
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy failing?
mypy is run differently in CI. If the failure is related to this check, please use the following command to run it locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10

Signed-off-by: Luka Govedic <luka.govedic@gmail.com>
@ProExpertProg ProExpertProg force-pushed the luka/vllm-ir/rms-norm branch from c59d43f to d98bda6 Compare March 26, 2026 14:53
@mergify mergify bot added the intel-gpu Related to Intel GPU label Mar 27, 2026
wendyliu235 (Contributor) commented:

@ProExpertProg We added a new Intel CI pipeline that only gates Intel PRs, so it does not apply to your PR. Feel free to ignore the result.


ProExpertProg commented Mar 27, 2026

Benchmarking results for Qwen-0.6B on B200:

main (7d6917bef5)

$ vllm bench latency --model=Qwen/Qwen3-0.6B --enforce-eager
Avg latency: 1.1205002814996987 seconds
10% percentile latency: 1.0299688416067512 seconds
25% percentile latency: 1.03302774974145 seconds
50% percentile latency: 1.037226719490718 seconds
75% percentile latency: 1.0453055927646346 seconds
90% percentile latency: 1.1188749122084127 seconds
99% percentile latency: 1.9936111562442966 seconds

$ vllm bench latency --model=Qwen/Qwen3-0.6B -cc.cudagraph_mode=FULL_DECODE_ONLY -cc.mode=NONE
Avg latency: 0.23219784893250714 seconds
10% percentile latency: 0.2200795673765242 seconds
25% percentile latency: 0.22023575179628097 seconds
50% percentile latency: 0.22870584204792976 seconds
75% percentile latency: 0.2294690987037029 seconds
90% percentile latency: 0.2681176986428909 seconds
99% percentile latency: 0.29146818072302266 seconds

$ vllm bench latency --model=Qwen/Qwen3-0.6B -cc.cudagraph_mode=FULL_DECODE_ONLY
Avg latency: 0.23715687717388695 seconds
10% percentile latency: 0.19969781424151734 seconds
25% percentile latency: 0.20861879669246264 seconds
50% percentile latency: 0.2095698295161128 seconds
75% percentile latency: 0.2866470882727299 seconds
90% percentile latency: 0.29861339112976565 seconds
99% percentile latency: 0.30874248172855007 seconds

$ vllm bench latency --model=Qwen/Qwen3-0.6B 
Avg latency: 0.19857681954745204 seconds
10% percentile latency: 0.1949726556893438 seconds
25% percentile latency: 0.1952334477682598 seconds
50% percentile latency: 0.1973732184851542 seconds
75% percentile latency: 0.20090638150577433 seconds
90% percentile latency: 0.20354869603179396 seconds
99% percentile latency: 0.20656762834638356 seconds

$ vllm bench latency --model=Qwen/Qwen3-0.6B -cc.custom_ops+=+rms_norm
Avg latency: 0.21674859452759848 seconds
10% percentile latency: 0.21210689584258943 seconds
25% percentile latency: 0.2122281797346659 seconds
50% percentile latency: 0.2141083375317976 seconds
75% percentile latency: 0.21661500504706055 seconds
90% percentile latency: 0.21756883854977788 seconds
99% percentile latency: 0.25163014805410056 seconds

PR

$ vllm bench latency --model=Qwen/Qwen3-0.6B --enforce-eager
Avg latency: 1.009949086928585 seconds
10% percentile latency: 1.0030865375651046 seconds
25% percentile latency: 1.006905260932399 seconds
50% percentile latency: 1.0103985135210678 seconds
75% percentile latency: 1.0125141454918776 seconds
90% percentile latency: 1.0159681793185882 seconds
99% percentile latency: 1.0241237908799667 seconds

$ vllm bench latency --model=Qwen/Qwen3-0.6B -cc.cudagraph_mode=FULL_DECODE_ONLY -cc.mode=NONE
Avg latency: 0.2692852795124054 seconds
10% percentile latency: 0.22030140799470246 seconds
25% percentile latency: 0.2212763840216212 seconds
50% percentile latency: 0.22949880751548335 seconds
75% percentile latency: 0.2765023543033749 seconds
90% percentile latency: 0.4044640638865531 seconds
99% percentile latency: 0.44274910243111665 seconds

$ vllm bench latency --model=Qwen/Qwen3-0.6B -cc.cudagraph_mode=FULL_DECODE_ONLY
Avg latency: 0.2334265397201913 seconds
10% percentile latency: 0.19989175074733795 seconds
25% percentile latency: 0.20082748675486073 seconds
50% percentile latency: 0.21009170851903036 seconds
75% percentile latency: 0.22850232501514256 seconds
90% percentile latency: 0.309309652983211 seconds
99% percentile latency: 0.36086777127929964 seconds

$ vllm bench latency --model=Qwen/Qwen3-0.6B 
Avg latency: 0.19729100486729295 seconds
10% percentile latency: 0.19479844837915153 seconds
25% percentile latency: 0.19507092799176462 seconds
50% percentile latency: 0.19574911502422765 seconds
75% percentile latency: 0.20004481880459934 seconds
90% percentile latency: 0.2008681818842888 seconds
99% percentile latency: 0.20560357752256095 seconds

$ vllm bench latency --model=Qwen/Qwen3-0.6B -cc.custom_ops+=+rms_norm --ir-op-priority.rms_norm=vllm_c
Avg latency: 0.21465967597517496 seconds
10% percentile latency: 0.2118833897053264 seconds
25% percentile latency: 0.21195765974698588 seconds
50% percentile latency: 0.21594521502265707 seconds
75% percentile latency: 0.2165221305040177 seconds
90% percentile latency: 0.21707573987077922 seconds
99% percentile latency: 0.21861423314083367 seconds

ProExpertProg (Collaborator, Author) commented:

main:

$ vllm serve openai/gpt-oss-120b -dp=2 -tp=2 -ep
local-completions ({'pretrained': 'openai/gpt-oss-120b', 'base_url': 'http://0.0.0.0:8000/v1/completions', 'num_concurrent': 50, 'max_retries': 3}), gen_kwargs: ({}), limit: None, num_fewshot: 5, batch_size: auto
| Tasks | Version | Filter | n-shot | Metric | Value | Stderr |
|---|---|---|---|---|---|---|
| gsm8k | 3 | flexible-extract | 5 | exact_match | 0.4610 | ± 0.0137 |
| | | strict-match | 5 | exact_match | 0.2904 | ± 0.0125 |

micah-wil (Contributor) commented:

I checked out this PR on MI355; here are my findings:

  • Unit tests check out on both MI300 and MI355

    • Failures in test_tp1_quant.py and test_fusion.py but they are also present on main

DeepSeek-R1:

  • Perf w/ AITER (via vllm bench serve and vllm bench latency) is the same with and without vLLM IR
  • Accuracy w/ AITER (via lm_eval) is the same with and without vLLM IR (got a score of 0.96)
  • I looked at traces with/without vllm IR and they were practically identical (most notably the rms norm fusions were applied correctly).

Llama-3.1-70B-Instruct

  • Perf with AITER (via vllm bench serve and vllm bench latency) is the same with and without vLLM IR
  • Perf without AITER is the same with and without vLLM IR
  • Accuracy with and without AITER (via lm_eval) is the same with and without vLLM IR (got scores around 0.88)
  • enforce-eager test with AITER enabled (VLLM_ROCM_USE_AITER=1 vllm bench latency /models/Llama-3.1-70B-Instruct --enforce-eager):
    • W/ vLLM IR: Avg latency: 46.51246352894232 seconds
    • Without vLLM IR: Avg latency: 45.510956873605025 seconds
    • The difference here appears to just be noise or maybe something to do with the node I tested on. I repeated this test and this 1s difference is not consistently reproducible. I also profiled & compared with/without vllm IR and didn't see any issues.
  • enforce-eager test with AITER disabled (vllm bench latency /models/Llama-3.1-70B-Instruct --enforce-eager):
    • W/ vLLM IR: Avg latency: 56.540494579216464 seconds
    • Without vLLM IR: Avg latency: 56.73963632159867 seconds

TL;DR: everything looks good from my initial look.

Signed-off-by: Luka Govedic <luka.govedic@gmail.com>

# Conflicts:
#	vllm/compilation/passes/pass_manager.py
@ProExpertProg ProExpertProg merged commit 40bb175 into vllm-project:main Apr 1, 2026
181 of 182 checks passed
@github-project-automation github-project-automation bot moved this from In progress to Done in torch.compile integration Apr 1, 2026
@github-project-automation github-project-automation bot moved this from Todo to Done in AMD Apr 1, 2026
gmagogsfm (Contributor) commented:

Woohoo!


Labels

  • ci/build
  • intel-gpu (Related to Intel GPU)
  • nvidia
  • ready (ONLY add when PR is ready to merge/full CI is needed)
  • ready-run-all-tests (Trigger CI with all tests for wide-ranging PRs)
  • rocm (Related to AMD ROCm)
  • torch.compile
  • vllm-ir (vLLM IR: intermediate representation and kernel registration)

Projects

Status: Done
