
Support sycl impl relu2_no_mul for NVIDIA-Nemotron-3-Nano-30B-A3B-bf16 #232

Merged
xinyu-intel merged 3 commits into vllm-project:main from Dboyqiao:dev/zhefeng/relu2_no_mul
Apr 8, 2026

Conversation

Dboyqiao (Contributor) commented Mar 29, 2026


Purpose

Add a SYCL kernel for relu2_no_mul; this is a follow-up enhancement to #200.
Compared to the torch implementation, it is ~55% faster at the op level (~44 us vs ~100 us) and improves each step by roughly 0.8%.
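
For context, a minimal sketch of the op's semantics, assuming relu2 denotes squared ReLU (as in the PyTorch reference implementation added under tests/ops/activation_op.py, per the review summary below):

    import torch

    def relu2_no_mul_ref(x: torch.Tensor) -> torch.Tensor:
        # Squared ReLU, elementwise. Unlike the *_and_mul ops there is
        # no gate/up split, so the output shape equals the input shape.
        return torch.square(torch.relu(x))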

Test Plan

python -m pytest tests/test_activation.py -v

Test Result

All tests pass.

(Optional) Documentation Update


Signed-off-by: Qiao, Zhefeng <zhefeng.qiao@intel.com>
Copilot AI review requested due to automatic review settings on March 29, 2026 15:26

Copilot AI left a comment


Pull request overview

Adds a SYCL/XPU implementation of the relu2_no_mul activation and wires it into the fused MoE path to support NVIDIA-Nemotron-3-Nano-30B-A3B-bf16 more efficiently.

Changes:

  • Register a new XPU custom op relu2_no_mul and implement its SYCL kernel.
  • Extend xpu_fused_moe to route activation="relu2_no_mul" and adjust GEMM2’s K accordingly.
  • Add unit-test coverage for the standalone activation op.

Reviewed changes

Copilot reviewed 7 out of 7 changed files in this pull request and generated 2 comments.

Summary per file:

  • vllm_xpu_kernels/fused_moe_interface.py: Adds relu2_no_mul activation handling and scales GEMM2's K by 2 for this activation.
  • csrc/activation.cpp: Implements the relu2_no_mul SYCL elementwise kernel and its dispatch.
  • csrc/ops.h: Declares the relu2_no_mul C++ entrypoint.
  • csrc/torch_bindings.cpp: Registers relu2_no_mul in the Torch extension library for XPU.
  • tests/ops/activation_op.py: Adds a Relu2NoMul CustomOp wrapper plus a PyTorch reference implementation.
  • tests/test_activation.py: Extends the activation tests to include relu2_no_mul.
  • tests/register_ops.py: Adds a Python test wrapper that calls torch.ops._C.relu2_no_mul.
Comments suppressed due to low confidence (1)

vllm_xpu_kernels/fused_moe_interface.py:252

  • act_output is allocated unconditionally with shape (num_moe_inputs, inter_size) but the relu2_no_mul branch immediately replaces it with torch.empty_like(gemm1_output). This results in an extra large allocation per call for that activation; consider allocating act_output inside each activation branch (or computing the needed output shape first) to avoid the wasted allocation.
    inter_size_scale = 1
    # act
    act_output = torch.empty((num_moe_inputs, inter_size),
                             dtype=gemm1_output.dtype,
                             device=gemm1_output.device)
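
A minimal sketch of the suggested restructuring (the branch layout and the silu_and_mul op name are assumptions based on the surrounding code, not the file's actual contents):

    # Allocate act_output inside each branch so elementwise activations
    # don't pay for an unused (num_moe_inputs, inter_size) buffer.
    if activation == "silu":
        act_output = torch.empty((num_moe_inputs, inter_size),
                                 dtype=gemm1_output.dtype,
                                 device=gemm1_output.device)
        torch.ops._C.silu_and_mul(act_output, gemm1_output)
        inter_size_scale = 1
    elif activation == "relu2_no_mul":
        act_output = torch.empty_like(gemm1_output)
        torch.ops._C.relu2_no_mul(act_output, gemm1_output)
        inter_size_scale = 2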


Comment on lines +281 to 282
K=inter_size * inter_size_scale,
num_experts=num_experts,

Copilot AI Mar 29, 2026


K is derived from inter_size * inter_size_scale, but cutlass_grouped_gemm_interface does not validate tensor shapes and will trust the provided K. For relu2_no_mul, this makes correctness/safety depend on w2 actually having the matching K-dimension (2*inter_size for non-int4 layouts). Please add an explicit shape assertion for w2 (and possibly w13) for this activation to prevent out-of-bounds reads if a caller passes incompatible weights.
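
One possible shape for such a guard, as a sketch (which w2 axis holds the K dimension is an assumption and would need to match the actual weight layout):

    # Hypothetical guard before dispatching GEMM2.
    expected_k = inter_size * inter_size_scale
    assert w2.shape[-1] == expected_k, (
        f"w2 K-dim {w2.shape[-1]} != expected {expected_k} "
        f"for activation={activation!r}")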

Comment thread: vllm_xpu_kernels/fused_moe_interface.py (outdated)
Comment on lines +259 to +262
elif activation == "relu2_no_mul":
act_output = torch.empty_like(gemm1_output)
torch.ops._C.relu2_no_mul(act_output, gemm1_output)
inter_size_scale = 2

Copilot AI Mar 29, 2026


The new relu2_no_mul activation path in xpu_fused_moe isn’t covered by the existing fused-MoE test suite (current tests/fused_moe/test_fused_moe.py cases all use activation="silu"). Please add a fused-MoE unit test that runs xpu_fused_moe(..., activation="relu2_no_mul") with appropriately-shaped w2 and compares against a reference implementation, so regressions in the K-scaling/activation behavior are caught.
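
A sketch of the reference such a test could compare against (the weight layouts and the naive per-token loop are assumptions for illustration, not the repo's actual test helpers):

    import torch

    def relu2_no_mul_ref(x: torch.Tensor) -> torch.Tensor:
        # Assumed semantics: squared ReLU, no gating.
        return torch.square(torch.relu(x))

    def fused_moe_relu2_ref(x, w13, w2, topk_ids, topk_weights):
        # Assumed layouts: x (T, H), w13 (E, 2I, H), w2 (E, H, 2I),
        # topk_ids / topk_weights (T, K). Naive loop, verification only.
        out = torch.zeros_like(x)
        for t in range(x.shape[0]):
            for k in range(topk_ids.shape[1]):
                e = int(topk_ids[t, k])
                h = relu2_no_mul_ref(w13[e] @ x[t])         # (2I,)
                out[t] += topk_weights[t, k] * (w2[e] @ h)  # (H,)
        return out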

xuechendi (Collaborator):

@jikunshang, please help review.

Comment thread: vllm_xpu_kernels/fused_moe_interface.py (outdated)
elif activation == "swigluoai" or ("SWIGLUOAI" in str(activation)):
torch.ops._C.swigluoai_and_mul(act_output, gemm1_output, 1.702, 7.0)
elif activation == "relu2_no_mul":
act_output = torch.empty_like(gemm1_output)
Collaborator:


Why is this needed?

Contributor (author):


The act_output shape for relu2_no_mul differs from that of the *_and_mul ops, so the allocation defined at line 250 cannot be reused.
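
To illustrate the shape difference with made-up sizes (a sketch, not code from the PR):

    import torch

    n, d = 4, 8
    gemm1_output = torch.randn(n, 2 * d)

    # Gated *_and_mul activations halve the last dim: (n, 2d) -> (n, d).
    gated_out = torch.empty((n, d), dtype=gemm1_output.dtype)

    # relu2_no_mul is purely elementwise: (n, 2d) -> (n, 2d).
    elementwise_out = torch.empty_like(gemm1_output)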

Collaborator:


Wouldn't it be better to fix this at line 250?

Comment thread: csrc/ops.h
double alpha = 1.702,
double limit = 7.0);

void relu2_no_mul(torch::Tensor& out, torch::Tensor& input);
Collaborator:


Minor: the vLLM repo doesn't have this CUDA kernel yet. I'd prefer to register this under torch.ops._xpu_C since it is an XPU-specific kernel (though it won't be used on the vLLM side yet).
Keeping it here is fine, though.

Dboyqiao force-pushed the dev/zhefeng/relu2_no_mul branch from 81203dc to bc0d9d3 on April 7, 2026 03:08
Signed-off-by: Qiao, Zhefeng <zhefeng.qiao@intel.com>
Dboyqiao force-pushed the dev/zhefeng/relu2_no_mul branch from bc0d9d3 to 396a18b on April 7, 2026 05:13
jikunshang (Collaborator):

Please rebase and fix conflicts.

Signed-off-by: Zhefeng, Qiao <zhefeng.qiao@intel.com>
xinyu-intel merged commit eea548c into vllm-project:main on Apr 8, 2026
8 checks passed
zufangzhu pushed a commit to zufangzhu/vllm-xpu-kernels that referenced this pull request Apr 8, 2026
Support sycl impl relu2_no_mul for NVIDIA-Nemotron-3-Nano-30B-A3B-bf16 (vllm-project#232)

Signed-off-by: Qiao, Zhefeng <zhefeng.qiao@intel.com>
Signed-off-by: Zhu, Zufang <zufang.zhu@intel.com>
jikunshang added a commit that referenced this pull request Apr 9, 2026
* [OneDNN] add mxfp8, mxfp4 onednn gemm  (#20)

* add mxfp4 onednn gemm

Signed-off-by: Zhu, Zufang <zufang.zhu@intel.com>

* add ut for mx

Signed-off-by: Zhu, Zufang <zufang.zhu@intel.com>

* fix

Signed-off-by: Zhu, Zufang <zufang.zhu@intel.com>

* format with pre-commit

Signed-off-by: Zhu, Zufang <zufang.zhu@intel.com>

* thanks copilot

Signed-off-by: Zhu, Zufang <zufang.zhu@intel.com>

---------

Signed-off-by: Zhu, Zufang <zufang.zhu@intel.com>

* format

Signed-off-by: Zhu, Zufang <zufang.zhu@intel.com>

* refine onednn gemm ut

Signed-off-by: Zhu, Zufang <zufang.zhu@intel.com>

* skip scales check (#256)

Signed-off-by: mayuyuace <qiming1.zhang@intel.com>
Signed-off-by: Zhu, Zufang <zufang.zhu@intel.com>

* Support sycl impl relu2_no_mul for NVIDIA-Nemotron-3-Nano-30B-A3B-bf16 (#232)

Signed-off-by: Qiao, Zhefeng <zhefeng.qiao@intel.com>
Signed-off-by: Zhu, Zufang <zufang.zhu@intel.com>

* Update test_fp8_gemm_onednn.py

Signed-off-by: Zhu, Zufang <zufang.zhu@intel.com>

---------

Signed-off-by: Zhu, Zufang <zufang.zhu@intel.com>
Signed-off-by: mayuyuace <qiming1.zhang@intel.com>
Signed-off-by: Qiao, Zhefeng <zhefeng.qiao@intel.com>
Co-authored-by: root <root@emr813693.jf.intel.com>
Co-authored-by: Qiming Zhang <qiming1.zhang@intel.com>
Co-authored-by: Zhefeng, Qiao <zhefeng.qiao@intel.com>
Co-authored-by: Kunshang Ji <kunshang.ji@intel.com>
