
Conversation

Contributor

@XiaobingSuper XiaobingSuper commented Oct 15, 2025

This PR makes the NVFP4 MoE kernel work for TP > 1. We don't need to check that the last dim (K) of the quantized weight scale is divisible by 16; the original weight already guarantees the required alignment.

Purpose

Test Plan

import torch

from vllm import LLM, SamplingParams


prompts = [
    "The Swiss Alps are",
    "Brad Marchand is",
    "The Boston Bruins are",
]


sampling_params = SamplingParams(temperature=0.90, top_p=0.95, max_tokens=40, min_tokens=10)
llm = LLM("nv-community/Qwen3-30B-A3B-FP4", tensor_parallel_size=2, enforce_eager=True)
#llm = LLM("nv-community/Qwen3-30B-A3B-FP4", tensor_parallel_size=4, enforce_eager=True)

responses = llm.generate(prompts, sampling_params)
for output in responses:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

Test Result

Prompt: 'The Swiss Alps are', Generated text: ' a mountainous region in the southern part of Switzerland, known for their stunning natural beauty and world-class ski resorts. The region is home to several major mountain ranges, including the Bernese Alps, the'
Prompt: 'Brad Marchand is', Generated text: ' a key player for the Boston Bruins, and his performance has been notable in various aspects of the game. In the 2022-23 season, Marchand played 82 games'
Prompt: 'The Boston Bruins are', Generated text: ' a professional ice hockey team based in Boston, Massachusetts, and they are one of the original National Hockey League (NHL) franchises. The team was founded in 1924, making them'


Signed-off-by: XiaobingSuper <[email protected]>
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request enables tensor parallelism for nv-fp4 MoE layers by removing several assertions in ModelOptNvFp4FusedMoE.process_weights_after_loading. The removed assertions were either redundant or overly strict, preventing the use of tensor parallelism (TP > 1) on certain models.

Specifically:

  • The dtype assertions for w13_weight_scale and w2_weight_scale were redundant, as the swizzle_blockscale function already performs this check.
  • The shape assertions, which required the last dimension of the weight scales to be divisible by 16, were too strict. With tensor parallelism, this dimension can be sharded and may not remain divisible by 16. The swizzle_blockscale function correctly handles this by padding the dimension to a multiple of 4, which is what the underlying kernel requires.

The changes are correct and necessary to support tensor parallelism for these quantized MoE layers. The code is now more robust and consistent with the actual kernel requirements.
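The padding argument above can be sketched with a small helper. Note this is a hypothetical illustration of the round-up arithmetic only, not the actual swizzle_blockscale implementation; the function name pad_to_multiple and the example shard sizes are invented for this sketch.

```python
def pad_to_multiple(dim: int, multiple: int) -> int:
    """Round dim up to the nearest multiple (the usual ceil-divide trick)."""
    return (dim + multiple - 1) // multiple * multiple


# With TP sharding, a scale's last dim may stop being a multiple of 16,
# but the kernel only needs a multiple of 4, which padding can restore.
assert pad_to_multiple(360, 4) == 360  # already aligned to 4, no padding needed
assert pad_to_multiple(362, 4) == 364  # padded up by 2 to reach alignment
assert 360 % 16 != 0                   # yet 360 would have tripped the old %16 assert
```

This is why dropping the strict %16 assertion is safe: the swizzle step already normalizes the dimension to what the kernel actually requires.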

Member

@mgoin mgoin left a comment


It looks like these asserts aren't present in the compressed-tensors backend, and I was able to run vllm serve RedHatAI/Qwen3-30B-A3B-NVFP4 -tp 2 fine, so I think this was just an oversight for ModelOpt.

# swizzle weight scales
layer.w13_weight_scale = torch.nn.Parameter(
    swizzle_blockscale(layer.w13_weight_scale), requires_grad=False
)
layer.w2_weight_scale = torch.nn.Parameter(
    swizzle_blockscale(layer.w2_weight_scale), requires_grad=False
)

@mgoin mgoin added the ready ONLY add when PR is ready to merge/full CI is needed label Oct 15, 2025
@mgoin mgoin changed the title make TP>1 works for nv-fp4 moe [ModelOpt] Remove NVFP4 MoE K%16==0 constraint Oct 15, 2025
Collaborator

@pavanimajety pavanimajety left a comment


Agree, it's the weight that should be divisible by 16, not the weight scale. Why are we removing the dtype assertion?

Collaborator

@pavanimajety pavanimajety left a comment


It seems like the dtype check is happening in swizzle_blockscale.

@mgoin mgoin merged commit d796375 into vllm-project:main Oct 15, 2025
62 checks passed
zhewenl pushed a commit to zhewenl/vllm that referenced this pull request Oct 16, 2025
mandy-li pushed a commit to mandy-li/vllm that referenced this pull request Oct 16, 2025
albertoperdomo2 pushed a commit to albertoperdomo2/vllm that referenced this pull request Oct 16, 2025
lywa1998 pushed a commit to lywa1998/vllm that referenced this pull request Oct 20, 2025
alhridoy pushed a commit to alhridoy/vllm that referenced this pull request Oct 24, 2025
xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request Oct 24, 2025
0xrushi pushed a commit to 0xrushi/vllm that referenced this pull request Oct 26, 2025
rtourgeman pushed a commit to rtourgeman/vllm that referenced this pull request Nov 10, 2025
Zhathw pushed a commit to Zhathw/vllm that referenced this pull request Nov 12, 2025
devpatelio pushed a commit to SumanthRH/vllm that referenced this pull request Nov 29, 2025
