
Feature/silu block quant fusion v1#32996

Merged
ProExpertProg merged 15 commits into vllm-project:main from Monishver11:feature/silu-block-quant-fusion-v1
Apr 1, 2026

Conversation

@Monishver11
Contributor

@Monishver11 Monishver11 commented Jan 24, 2026

Purpose

Adds a CUDA kernel and torch.compile pattern matching for fused SiluMul + groupwise FP8 quantization. For #27847

Test Result

Experiments were run on an NVIDIA GeForce RTX 4070 with CUDA 13.0.

Test fused op:

pytest tests/kernels/core/test_fused_silu_mul_block_quant.py

(vllm-dev) [mc10322@cuda5 vllm]$ pytest tests/kernels/core/test_fused_silu_mul_block_quant.py
============================================================ test session starts ============================================================
platform linux -- Python 3.10.19, pytest-8.3.5, pluggy-1.6.0
rootdir: /scratch/mc10322/vllm
configfile: pyproject.toml
plugins: anyio-4.12.1
collected 330 items

tests/kernels/core/test_fused_silu_mul_block_quant.py ............................................................................... [ 23%]
..................................................................................................................................... [ 64%]
......................................................................................................................                [100%]

============================================================= warnings summary ==============================================================
<frozen importlib._bootstrap>:241
  <frozen importlib._bootstrap>:241: DeprecationWarning: builtin type SwigPyPacked has no __module__ attribute

<frozen importlib._bootstrap>:241
  <frozen importlib._bootstrap>:241: DeprecationWarning: builtin type SwigPyObject has no __module__ attribute

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
================================================ 330 passed, 2 warnings in 98.97s (0:01:38) =================================================
sys:1: DeprecationWarning: builtin type swigvarlink has no __module__ attribute

Microbenchmark isolated op:

python benchmarks/fused_kernels/silu_mul_block_quant_benchmark.py

[------------------------------------------------------ silu-mul-block-quant ------------------------------------------------------]
                                                     |  unfused_fp8_impl  |  unfused_groupwise_fp8_impl  |  fused_groupwise_fp8_impl
1 threads: -------------------------------------------------------------------------------------------------------------------------
      N 16 x D 1024 x DT torch.float16 x GS 64       |       278.0        |            321.6             |           133.4
      N 16 x D 1024 x DT torch.float16 x GS 128      |       278.7        |            320.1             |           133.6
      N 16 x D 1024 x DT torch.bfloat16 x GS 64      |       279.5        |            321.2             |           133.5
      N 16 x D 1024 x DT torch.bfloat16 x GS 128     |       278.9        |            321.3             |           133.1
      N 16 x D 2048 x DT torch.float16 x GS 64       |       277.5        |            325.1             |           133.0
      N 16 x D 2048 x DT torch.float16 x GS 128      |       278.9        |            321.6             |           133.1
      N 16 x D 2048 x DT torch.bfloat16 x GS 64      |       277.6        |            320.2             |           133.2
      N 16 x D 2048 x DT torch.bfloat16 x GS 128     |       278.7        |            320.4             |           133.8
      N 16 x D 4096 x DT torch.float16 x GS 64       |       278.5        |            320.7             |           133.0
      N 16 x D 4096 x DT torch.float16 x GS 128      |       279.6        |            321.3             |           133.4
      N 16 x D 4096 x DT torch.bfloat16 x GS 64      |       278.3        |            321.4             |           132.5
      N 16 x D 4096 x DT torch.bfloat16 x GS 128     |       277.9        |            322.5             |           132.3
      N 16 x D 5120 x DT torch.float16 x GS 64       |       277.4        |            319.6             |           132.5
      N 16 x D 5120 x DT torch.float16 x GS 128      |       278.1        |            320.2             |           132.5
      N 16 x D 5120 x DT torch.bfloat16 x GS 64      |       276.9        |            319.2             |           132.7
      N 16 x D 5120 x DT torch.bfloat16 x GS 128     |       277.1        |            319.5             |           132.7
      N 16 x D 14336 x DT torch.float16 x GS 64      |       277.6        |            319.1             |           132.2
      N 16 x D 14336 x DT torch.float16 x GS 128     |       277.5        |            319.5             |           132.5
      N 16 x D 14336 x DT torch.bfloat16 x GS 64     |       278.0        |            321.8             |           131.9
      N 16 x D 14336 x DT torch.bfloat16 x GS 128    |       277.0        |            321.3             |           132.1
      N 128 x D 1024 x DT torch.bfloat16 x GS 64     |       276.8        |            318.4             |           132.1
      N 128 x D 1024 x DT torch.bfloat16 x GS 128    |       283.5        |            317.0             |           131.9
      N 128 x D 2048 x DT torch.float16 x GS 64      |       275.4        |            316.7             |           131.5
      N 128 x D 2048 x DT torch.float16 x GS 128     |       274.8        |            322.1             |           131.1
      N 128 x D 2048 x DT torch.bfloat16 x GS 64     |       274.7        |            316.2             |           131.2
      N 128 x D 2048 x DT torch.bfloat16 x GS 128    |       273.0        |            317.3             |           130.8
      N 128 x D 4096 x DT torch.float16 x GS 64      |       273.6        |            316.1             |           130.9
      N 128 x D 4096 x DT torch.float16 x GS 128     |       274.6        |            315.7             |           131.0
      N 128 x D 4096 x DT torch.bfloat16 x GS 64     |       273.7        |            315.0             |           130.4
      N 128 x D 4096 x DT torch.bfloat16 x GS 128    |       272.4        |            314.9             |           130.2
      N 128 x D 5120 x DT torch.float16 x GS 64      |       273.2        |            315.4             |           130.7
      N 128 x D 5120 x DT torch.float16 x GS 128     |       273.4        |            315.5             |           130.4
      N 128 x D 5120 x DT torch.bfloat16 x GS 64     |       272.9        |            314.2             |           130.9
      N 128 x D 5120 x DT torch.bfloat16 x GS 128    |       275.3        |            315.3             |           130.1
      N 128 x D 14336 x DT torch.float16 x GS 64     |       274.1        |            316.3             |           130.7
      N 128 x D 14336 x DT torch.float16 x GS 128    |       274.2        |            318.3             |           130.6
      N 128 x D 14336 x DT torch.bfloat16 x GS 64    |       273.9        |            316.3             |           130.7
      N 128 x D 14336 x DT torch.bfloat16 x GS 128   |       274.6        |            315.8             |           130.5
      N 512 x D 1024 x DT torch.float16 x GS 64      |       271.9        |            313.7             |           130.1
      N 512 x D 1024 x DT torch.float16 x GS 128     |       271.0        |            313.2             |           130.5
      N 512 x D 1024 x DT torch.bfloat16 x GS 64     |       270.6        |            312.2             |           129.2
      N 512 x D 1024 x DT torch.bfloat16 x GS 128    |       271.0        |            313.5             |           129.9
      N 512 x D 2048 x DT torch.float16 x GS 64      |       270.4        |            314.2             |           130.2
      N 512 x D 2048 x DT torch.float16 x GS 128     |       271.2        |            313.4             |           129.9
      N 512 x D 2048 x DT torch.bfloat16 x GS 64     |       271.7        |            312.3             |           130.2
      N 512 x D 2048 x DT torch.bfloat16 x GS 128    |       271.0        |            313.7             |           129.7
      N 512 x D 4096 x DT torch.float16 x GS 64      |       272.0        |            313.9             |           129.9
      N 512 x D 4096 x DT torch.float16 x GS 128     |       272.1        |            315.3             |           130.8
      N 512 x D 4096 x DT torch.bfloat16 x GS 64     |       270
      N 512 x D 5120 x DT torch.bfloat16 x GS 128    |       271.3        |            313.5             |           129.4
      N 512 x D 14336 x DT torch.float16 x GS 64     |       274.0        |            316.3             |           130.4
      N 512 x D 14336 x DT torch.float16 x GS 128    |       272.6        |            313.3             |           130.4
      N 512 x D 14336 x DT torch.bfloat16 x GS 64    |       273.0        |            314.7             |           130.1
      N 512 x D 14336 x DT torch.bfloat16 x GS 128   |       273.6        |            315.2             |           129.5
      N 2048 x D 1024 x DT torch.float16 x GS 64     |       270.2        |            313.2             |           130.0
      N 2048 x D 1024 x DT torch.float16 x GS 128    |       271.1        |            313.3             |           129.7
      N 2048 x D 1024 x DT torch.bfloat16 x GS 64    |       269.5        |            312.0             |           129.5
      N 2048 x D 1024 x DT torch.bfloat16 x GS 128   |       269.9        |            340.8             |           128.9
      N 2048 x D 2048 x DT torch.float16 x GS 64     |       271.3        |            313.1             |           129.0
      N 2048 x D 2048 x DT torch.float16 x GS 128    |       270.7        |            312.4             |           128.9
      N 2048 x D 2048 x DT torch.bfloat16 x GS 64    |       271.2        |            312.1             |           129.2
      N 2048 x D 2048 x DT torch.bfloat16 x GS 128   |       270.7        |            312.6             |           128.3
      N 2048 x D 4096 x DT torch.float16 x GS 64     |       270.8        |            313.7             |           140.6
      N 2048 x D 4096 x DT torch.float16 x GS 128    |       272.2        |            313.0             |           140.2
      N 2048 x D 4096 x DT torch.bfloat16 x GS 64    |       271.3        |            313.5             |           140.9
      N 2048 x D 4096 x DT torch.bfloat16 x GS 128   |       272.9        |            313.9             |           140.1
      N 2048 x D 5120 x DT torch.float16 x GS 64     |       293.2        |            333.5             |           180.2
      N 2048 x D 5120 x DT torch.float16 x GS 128    |       294.0        |            313.0             |           178.8
      N 2048 x D 5120 x DT torch.bfloat16 x GS 64    |       294.7        |            312.8             |           181.0
      N 2048 x D 5120 x DT torch.bfloat16 x GS 128   |       294.8        |            315.1             |           178.4
      N 2048 x D 14336 x DT torch.float16 x GS 64    |       997.6        |            971.5             |           503.7
      N 2048 x D 14336 x DT torch.float16 x GS 128   |       997.3        |            847.3             |           499.4
      N 2048 x D 14336 x DT torch.bfloat16 x GS 64   |       997.4        |            854.7             |           503.7
      N 2048 x D 14336 x DT torch.bfloat16 x GS 128  |       997.2        |            846.5             |           499.2

Times are in microseconds (us).
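The large-N speedups are consistent with this op being memory-bandwidth-bound: fusing eliminates the intermediate half-precision activation round trip through global memory. A back-of-envelope traffic estimate can illustrate this (my assumptions, not vLLM code: fp16 activations, fp8 output, fp32 per-group scales, caches ignored):

```python
def traffic_bytes(n: int, d: int, group_size: int, fused: bool) -> int:
    """Rough global-memory traffic for SiLU-mul + groupwise FP8 quant.

    Hypothetical model: fp16 = 2 bytes, fp8 = 1 byte, fp32 scale = 4 bytes.
    """
    scale_bytes = (d // group_size) * 4
    if fused:
        # One kernel: read gate+up (2*d fp16), write fp8 result + scales.
        return n * (2 * d * 2 + d * 1 + scale_bytes)
    # silu_mul: read 2*d fp16, write d fp16 intermediate.
    silu_mul = n * (2 * d * 2 + d * 2)
    # quant: re-read d fp16, write d fp8 + scales.
    quant = n * (d * 2 + d * 1 + scale_bytes)
    return silu_mul + quant

# For the largest shape above (N 2048 x D 14336 x GS 128), the unfused
# path moves roughly 1.8x the bytes, in line with the ~2x measured speedup.
ratio = traffic_bytes(2048, 14336, 128, False) / traffic_bytes(2048, 14336, 128, True)
```

At small N the kernel is launch-latency-dominated rather than bandwidth-dominated, which is why the fused timings there are nearly flat across shapes.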

Compilation pattern matching of the fused op:

(vllm) [mc10322@cuda5 vllm]$ pytest tests/compile/passes/test_silu_mul_quant_fusion.py
/scratch/mc10322/vllm/.venv/lib/python3.12/site-packages/pytest_asyncio/plugin.py:208: PytestDeprecationWarning: The configuration option "asyncio_default_fixture_loop_scope" is unset.
The event loop scope for asynchronous fixtures will default to the fixture caching scope. Future versions of pytest-asyncio will default the loop scope for asynchronous fixtures to function scope. Set the default fixture loop scope explicitly in order to avoid unexpected behavior in the future. Valid fixture loop scopes are: "function", "class", "module", "package", "session"

  warnings.warn(PytestDeprecationWarning(_DEFAULT_FIXTURE_LOOP_SCOPE_UNSET))
=================================== test session starts ====================================
platform linux -- Python 3.12.7, pytest-8.3.5, pluggy-1.5.0
rootdir: /scratch/mc10322/vllm
configfile: pyproject.toml
plugins: schemathesis-3.39.15, hydra-core-1.3.2, shard-0.1.2, subtests-0.14.1, hypothesis-6.131.0, mock-3.14.0, cov-6.3.0, forked-1.6.0, asyncio-0.24.0, rerunfailures-14.0, timeout-2.3.1, buildkite-test-collector-0.1.9, anyio-4.12.1
asyncio: mode=Mode.STRICT, default_loop_scope=None
collected 160 items
Running 160 items in this shard

tests/compile/passes/test_silu_mul_quant_fusion.py ................................. [ 20%]
...............................................................sssssssssssssssssssss [ 73%]
sssssssssss........ssssssss........ssssssss                                          [100%]

lm_eval & Benchmarks

Model: Qwen2.5-0.5B-Instruct (FP8_BLOCK quantized via llm-compressor)
GPU: NVIDIA RTX 4070 (12GB)

Note: Used a 0.5B model due to GPU memory constraints. Fusion impact would be more pronounced on larger models with wider MLP intermediate sizes.

lm_eval (gsm8k, 5-shot, 250 samples)

fusion disabled:

vllm ({'pretrained': './Qwen2.5-0.5B-FP8-Block', 'tensor_parallel_size': 1, 'max_model_len': 2048, 'gpu_memory_utilization': 0.9, 'compilation_config': {'pass_config': {'fuse_act_quant': False}}}), gen_kwargs: ({}), limit: 250.0, num_fewshot: 5, batch_size: auto
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.348|±  |0.0302|
|     |       |strict-match    |     5|exact_match|↑  |0.336|±  |0.0299|

fusion enabled:

vllm ({'pretrained': './Qwen2.5-0.5B-FP8-Block', 'tensor_parallel_size': 1, 'max_model_len': 2048, 'gpu_memory_utilization': 0.9, 'compilation_config': {'custom_ops': ['+silu_and_mul'], 'pass_config': {'fuse_act_quant': True}}}), gen_kwargs: ({}), limit: 250.0, num_fewshot: 5, batch_size: auto
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  | 0.34|±  |0.0300|
|     |       |strict-match    |     5|exact_match|↑  | 0.32|±  |0.0296|

+silu_and_mul disabled, fusion enabled:

vllm ({'pretrained': './Qwen2.5-0.5B-FP8-Block', 'tensor_parallel_size': 1, 'max_model_len': 2048, 'gpu_memory_utilization': 0.9}), gen_kwargs: ({}), limit: 250.0, num_fewshot: 5, batch_size: auto
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  | 0.34|±  |0.0300|
|     |       |strict-match    |     5|exact_match|↑  | 0.32|±  |0.0296|

Results within error bars — no accuracy degradation.

Serving Benchmark (sonnet, 640 prompts, 128 RPS)

default (no +silu_and_mul, equivalent to main branch for FP8_BLOCK models):

============ Serving Benchmark Result ============
Successful requests:                     640
Failed requests:                         0
Request rate configured (RPS):           128.00
Benchmark duration (s):                  20.28
Total input tokens:                      162448
Total generated tokens:                  73031
Request throughput (req/s):              31.55
Output token throughput (tok/s):         3600.37
Peak output token throughput (tok/s):    9956.00
Peak concurrent requests:                640.00
Total token throughput (tok/s):          11608.92
---------------Time to First Token----------------
Mean TTFT (ms):                          12791.15
Median TTFT (ms):                        12855.42
P99 TTFT (ms):                           15077.32
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          26.35
Median TPOT (ms):                        26.41
P99 TPOT (ms):                           43.45
---------------Inter-token Latency----------------
Mean ITL (ms):                           28.84
Median ITL (ms):                         26.33
P99 ITL (ms):                            74.21
==================================================

+silu_and_mul enabled, fusion disabled:

============ Serving Benchmark Result ============
Successful requests:                     640
Failed requests:                         0
Request rate configured (RPS):           128.00
Benchmark duration (s):                  10.56
Total input tokens:                      162448
Total generated tokens:                  70978
Request throughput (req/s):              60.59
Output token throughput (tok/s):         6719.51
Peak output token throughput (tok/s):    8256.00
Peak concurrent requests:                613.00
Total token throughput (tok/s):          22098.53
---------------Time to First Token----------------
Mean TTFT (ms):                          2615.44
Median TTFT (ms):                        2281.38
P99 TTFT (ms):                           5892.83
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          25.70
Median TPOT (ms):                        27.69
P99 TPOT (ms):                           62.81
---------------Inter-token Latency----------------
Mean ITL (ms):                           34.05
Median ITL (ms):                         30.34
P99 ITL (ms):                            159.77
==================================================

+silu_and_mul enabled, fusion enabled:

============ Serving Benchmark Result ============
Successful requests:                     640
Failed requests:                         0
Request rate configured (RPS):           128.00
Benchmark duration (s):                  8.30
Total input tokens:                      162448
Total generated tokens:                  71908
Request throughput (req/s):              77.14
Output token throughput (tok/s):         8666.61
Peak output token throughput (tok/s):    7460.00
Peak concurrent requests:                406.00
Total token throughput (tok/s):          28245.44
---------------Time to First Token----------------
Mean TTFT (ms):                          1278.96
Median TTFT (ms):                        962.05
P99 TTFT (ms):                           3750.41
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          14.20
Median TPOT (ms):                        14.47
P99 TPOT (ms):                           29.54
---------------Inter-token Latency----------------
Mean ITL (ms):                           26.94
Median ITL (ms):                         19.31
P99 ITL (ms):                            109.77
==================================================

+silu_and_mul disabled, fusion enabled:

Maximum request concurrency: None
Request rate configured (RPS):           128.00
Benchmark duration (s):                  9.24
Total input tokens:                      162448
Total generated tokens:                  71084
Request throughput (req/s):              69.28
Output token throughput (tok/s):         7695.29
Peak output token throughput (tok/s):    8513.00
Peak concurrent requests:                517.00
Total token throughput (tok/s):          25281.30
---------------Time to First Token----------------
Mean TTFT (ms):                          1500.53
Median TTFT (ms):                        958.20
P99 TTFT (ms):                           4212.73
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          23.35
Median TPOT (ms):                        23.98
P99 TPOT (ms):                           38.88
---------------Inter-token Latency----------------
Mean ITL (ms):                           30.95
Median ITL (ms):                         28.44
P99 ITL (ms):                            88.56
==================================================

@github-actions

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small and essential subset of tests to catch errors quickly.

You can ask your reviewers to trigger select CI tests on top of fastcheck CI.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

🚀

@Monishver11 Monishver11 marked this pull request as draft January 24, 2026 06:36
@mergify mergify bot added ci/build performance Performance-related issues labels Jan 24, 2026
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a fused CUDA kernel for SiLU, multiplication, and block-wise FP8 quantization, along with corresponding benchmarks, tests, and integration into the torch.compile fusion passes. The new kernel shows significant performance improvements in the provided benchmarks.

My review has identified a couple of important issues:

  1. A critical issue in the torch.compile fusion pass where the pattern for the new fused kernel is hardcoded for a single group_size, which will prevent fusion for other supported sizes.
  2. A high-severity issue in the CUDA kernel implementation regarding a hardcoded shared memory size, which makes the code brittle and prone to future bugs.

Addressing these points will improve the correctness and maintainability of the new feature. The rest of the changes, including the tests and benchmark code, look solid.


@cursor cursor bot left a comment


Cursor Bugbot has reviewed your changes and found 2 potential issues.

Bugbot Autofix is OFF. To automatically fix reported issues with Cloud Agents, enable Autofix in the Cursor dashboard.

Comment @cursor review or bugbot run to trigger another review on this PR

@Monishver11
Contributor Author

Hello @ProExpertProg. I've created the kernel for SiluMul+BlockQuant fusion, and it's working correctly (though not yet performant enough). I'm still having some issues with the fusion pass and pattern matching, which I'll keep working on. In the meantime, I'd like feedback on the kernel and how you think it can be made more optimized and efficient.

@ElizaWszola, I used your #27883 PR as a reference to understand the internal workings. Thanks for it. If you could also review the kernel and point out anything I missed, that would be really helpful.

I'll look at and fix the issues raised by the bots shortly.

@Monishver11
Contributor Author

Hello @ProExpertProg, quick follow-up. When you have some time, let me know your thoughts on the kernel part.

@Monishver11
Contributor Author

Hello @ProExpertProg, I've updated my kernel, and it now performs better than the unfused implementation. Could you take a look and share your review? Thanks.

@weimin023

Hello @ProExpertProg,
The Triton kernel in #33026 also needs your comments, thanks!

@Monishver11 Monishver11 marked this pull request as ready for review February 12, 2026 01:48
@Monishver11
Contributor Author

Hello @ProExpertProg,

I'm having some trouble with the fusion pattern match. Could you please provide some guidance on this? I have tried using the existing matchers for silu_mul and block_quant, as well as expressing silu_mul inline, but I still can’t get it to match the pattern for replacement with the fused kernel.

Additionally, I noticed that the silu_and_mul_per_block_quant_kernel_large kernel performs well across all cases, as it parallelizes computations on a per-token and per-group basis. If you have any suggestions for further optimizations, I would greatly appreciate it. Thank you.

@mergify

mergify bot commented Feb 12, 2026

Hi @Monishver11, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy or markdownlint failing?
mypy and markdownlint are run differently in CI. If the failure is related to either of these checks, please use the following commands to run them locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10
# For markdownlint
pre-commit run --hook-stage manual markdownlint

@mergify

mergify bot commented Feb 12, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @Monishver11.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Feb 12, 2026
@ProExpertProg
Collaborator

Have you tried using the VLLM_PATTERN_MATCH_DEBUG env variable? You can set it to the name of the node in the graph you expect the match to occur at (the node of the first return from the pattern).

@mergify mergify bot removed the needs-rebase label Feb 16, 2026
@Monishver11 Monishver11 force-pushed the feature/silu-block-quant-fusion-v1 branch from e18d654 to a528b86 Compare February 16, 2026 00:39
@mergify

mergify bot commented Feb 16, 2026

Same pre-commit failure notice as the Feb 12 comment above.

1 similar comment: the mergify bot posted the same pre-commit failure notice on Feb 26, 2026.

@Monishver11
Contributor Author

Hello @ProExpertProg. Thanks for your last comment. I've now fixed the kernel fusion pattern match. Can you please review this PR?

@mergify mergify bot added frontend llama Related to Llama models multi-modality Related to multi-modality (#4194) new-model Requests to new models qwen Related to Qwen models gpt-oss Related to GPT-OSS models nvidia labels Mar 27, 2026
@mergify mergify bot added the rocm Related to AMD ROCm label Mar 27, 2026
@github-project-automation github-project-automation bot moved this from To Triage to Ready in gpt-oss Issues & Enhancements Mar 27, 2026
@ProExpertProg
Collaborator

One more request actually: could you update docs/design/fusions.md to mention this kernel is now supported? And can you check that this fusion is enabled by default for applicable models?

Signed-off-by: Monishver Chandrasekaran <monishverchandrasekaran@gmail.com>
@mergify

mergify bot commented Mar 31, 2026

Documentation preview: https://vllm--32996.org.readthedocs.build/en/32996/

@mergify

mergify bot commented Mar 31, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @Monishver11.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@Monishver11
Contributor Author

Hello @ProExpertProg. I've updated docs/design/fusions.md with this new kernel. And for the default enablement: yes, this fusion is automatically enabled at O1+ via enable_act_fusion, which returns True when either the silu_and_mul or quant_fp8 custom op is active. No config changes needed.
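For reference, the toggles exercised in the lm_eval runs earlier in this thread can also be spelled out explicitly. This is a sketch of the compilation_config dict from those runs; exact key names may vary across vLLM versions.

```python
# Sketch mirroring the compilation_config used in the lm_eval runs above;
# key names follow those runs and may differ across vLLM versions.
compilation_config = {
    "custom_ops": ["+silu_and_mul"],          # keep the custom silu_and_mul op
    "pass_config": {"fuse_act_quant": True},  # enable the act+quant fusion pass
}
```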

@gshtras
Collaborator

gshtras commented Apr 1, 2026

Breaks on ROCm
AttributeError: '_OpNamespace' '_C' object has no attribute 'silu_and_mul'

@Monishver11
Contributor Author

Breaks on ROCm AttributeError: '_OpNamespace' '_C' object has no attribute 'silu_and_mul'

Thanks for the fix @gshtras.
