[Graph][Fusion] Add QKVNormRope and QKVNormRopeWithBias #5721

Merged
yiz-liu merged 2 commits into vllm-project:main from ForBetterCodeNine:qv_norm_0108
Jan 22, 2026
Conversation

@ForBetterCodeNine
Contributor

ForBetterCodeNine commented Jan 8, 2026

What this PR does / why we need it?

This PR builds upon PR #5011 and further enhances the npu_graph_ex_passes module. On top of that work, we add graph optimization support for the add_rms_quant fused operator in scenarios where a bias term is present, ensuring the fusion pattern is correctly registered and matched in the computation graph.

For validation, we used the Qwen3-235B-A22B-W8A8 model for QKVNormRopeWithBias and the Qwen3-32B model for QKVNormRope. Benchmark results show that, compared to the unfused baseline, enabling this fusion pass significantly improves inference throughput for W8A8 quantized models.
For more details, refer to RFC #4715.
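The core idea of such a fusion pass can be sketched without any framework: scan the op sequence for the QK-norm + RoPE subpattern and collapse it into one fused op. A minimal, framework-free illustration (the op names are hypothetical stand-ins, not the real FX node targets; the actual pass matches FX graph nodes):

```python
# Toy fusion pass: collapse the (q_rms_norm, k_rms_norm, rope) pattern
# into a single fused op, the way the real pass swaps in one NPU kernel.
def fuse_qknorm_rope(ops):
    fused, i = [], 0
    pattern = ["q_rms_norm", "k_rms_norm", "rope"]
    while i < len(ops):
        if ops[i:i + 3] == pattern:
            fused.append("qkv_norm_rope_fused")  # one kernel launch instead of three
            i += 3
        else:
            fused.append(ops[i])
            i += 1
    return fused

graph = ["qkv_proj", "q_rms_norm", "k_rms_norm", "rope", "attention"]
print(fuse_qknorm_rope(graph))
# → ['qkv_proj', 'qkv_norm_rope_fused', 'attention']
```

The real implementation registers a pattern/replacement pair with the graph compiler rather than scanning a list, but the match-and-substitute shape is the same.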

Does this PR introduce any user-facing change?

No

How was this patch tested?

```python
llm = LLM(
    model=model,
    tensor_parallel_size=GPUs_per_dp_rank,
    enforce_eager=False,
    enable_expert_parallel=enable_expert_parallel,
    trust_remote_code=trust_remote_code,
    gpu_memory_utilization=0.98,
    max_num_batched_tokens=512,
    # load_format="dummy",
    max_model_len=2048,
    max_num_seqs=16,
    quantization="ascend",
    additional_config={
        "refresh": True,
        "enable_npugraph_ex": True
    },
    compilation_config={
        "cudagraph_capture_sizes": [8, 16],
        "cudagraph_mode": "FULL_DECODE_ONLY",
    },
)
if profile_dir:
    llm.start_profile()
outputs = llm.generate(prompts, sampling_params)
if profile_dir:
    llm.stop_profile()
for i, output in enumerate(outputs):
    if i >= 5:
        break
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(
        f"DP rank {global_dp_rank}, Prompt: {prompt!r}, "
        f"Generated text: {generated_text!r}"
    )
```

@gemini-code-assist
Contributor

gemini-code-assist Bot left a comment

Code Review

This pull request introduces graph fusion optimizations for QKVNormRope and QKVNormRopeWithBias. While the fusion logic itself seems correct, the mechanism for registering these fusions is fragile. It relies on a global compiler instance to pass configuration at module import time, which is an anti-pattern that can lead to silent failures and makes the code difficult to maintain. My review provides critical feedback on this design and suggests a refactoring to an explicit registration pattern, which will improve the robustness and clarity of the code.

Comment on lines +36 to +48
```python
# Global reference to the current compiler instance
_current_compiler = None


def set_current_compiler(compiler: "AscendCompiler") -> None:
    """Set the current compiler instance globally."""
    global _current_compiler
    _current_compiler = compiler


def get_current_compiler() -> Optional["AscendCompiler"]:
    """Get the current compiler instance globally."""
    return _current_compiler
```
Contributor

critical

The introduction of a global _current_compiler to pass configuration to other modules is a fragile design choice. It creates hidden dependencies on execution order (e.g., AscendCompiler.compute_hash must be called before fusion modules are imported) and makes the code harder to reason about and test. A more robust approach would be to pass configuration explicitly, avoiding global state. I've added a concrete refactoring suggestion in vllm_ascend/compilation/npugraph_ex_passes/add_qvnorm_rope_fusion.py that would resolve this by triggering registration explicitly from npugraph_ex_compile.
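The import-order hazard described above can be reproduced in a few lines of plain Python (hypothetical names, not the PR's actual modules): a lookup that runs before the global is set silently yields the fallback instead of failing.

```python
# Toy reproduction of the hazard: a lazy lookup that silently falls
# back to a default when the global hasn't been populated yet.
_current = None

def set_current(compiler):
    global _current
    _current = compiler

def lazy_lookup():
    # Mirrors the fallback path in get_qknorm_rope_vllm_config():
    # no error, just a default that may misconfigure the fusions.
    return _current if _current is not None else "default-config"

before = lazy_lookup()    # called before the global is set
set_current("real-config")
after = lazy_lookup()
print(before, after)      # → default-config real-config
```

Because the fallback is silent, a module imported too early registers against the default configuration and nothing surfaces in the logs above debug level.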

Comment on lines +217 to +269
```python
def get_qknorm_rope_vllm_config() -> VllmConfig:
    """
    Get the vllm_config from the current compiler instance.
    This function is called at module load time to lazily register fusions.

    Returns:
        VllmConfig object from the compiler, or a default one if not available.
    """
    try:
        from vllm_ascend.compilation.compiler_interface import get_current_compiler

        compiler = get_current_compiler()
        if compiler and hasattr(compiler, 'vllm_config'):
            return compiler.vllm_config
    except (ImportError, AttributeError):
        pass

    # Fallback to default config
    return VllmConfig()


# Lazy initialization: try to register with compiler's vllm_config if available
try:
    vllm_config = get_qknorm_rope_vllm_config()
    attn_layers: dict[str, Attention] = get_layers_from_vllm_config(
        vllm_config, Attention)

    if len(attn_layers) == 0:
        logger.debug(
            "QKNorm and Rope fusion enabled, but no Attention layers were discovered."
        )

    layer = next(iter(attn_layers.values()))
    if layer.head_size != 128:
        logger.debug(
            "QKNorm and Rope fusion not enabled: head_dim %d is not equal to 128",
            layer.head_size)

    # register converter for pass
    common_epsilons = [1e-5, 1e-6]
    for eps in common_epsilons:
        logger.info(
            f"Start register fusion pattern for QKNormRope with epsilons={eps}"
        )
        register_qknorm_rope_fusion(layer.head_size, layer.num_heads,
                                    layer.num_kv_heads, eps)
        register_qknorm_rope_fusion_with_bias(layer.head_size, layer.num_heads,
                                              layer.num_kv_heads, eps)
except Exception as e:
    logger.debug(
        f"Failed to register QKNorm and Rope fusions at module load time: {e}. "
        f"This is expected if the compiler hasn't been initialized yet."
    )
```
Contributor

critical

This module-level fusion registration logic is fragile and has several issues:

  1. Dependency on global state: It relies on a global compiler instance being set before import. If not, it silently falls back to a default VllmConfig, which can lead to fusions being skipped or misconfigured without a clear error.
  2. Error masking: The broad except Exception hides underlying problems during registration, logging them only at the debug level. This can make debugging difficult.
  3. Potential crash: If attn_layers is empty, next(iter(attn_layers.values())) on line 249 will raise a StopIteration exception. This is currently masked by the broad exception but should be handled gracefully.

To make this more robust, I recommend replacing this block with an explicit registration function that is called from npugraph_ex_compile with the vllm_config. This removes the need for global state and makes dependencies clear.

```python
_FUSIONS_REGISTERED = False


def register_qkvnorm_rope_fusions(vllm_config: VllmConfig):
    """
    Registers QKNorm+Rope fusion patterns with torchair.
    This function is designed to be called explicitly once the vllm_config is available.
    """
    global _FUSIONS_REGISTERED
    if _FUSIONS_REGISTERED:
        return

    try:
        attn_layers: dict[str, Attention] = get_layers_from_vllm_config(
            vllm_config, Attention)

        if not attn_layers:
            logger.debug(
                "QKNorm and Rope fusion enabled, but no Attention layers were discovered."
            )
            return

        layer = next(iter(attn_layers.values()))
        if layer.head_size != 128:
            logger.debug(
                "QKNorm and Rope fusion not enabled: head_dim %d is not equal to 128",
                layer.head_size)
            return

        # register converter for pass
        common_epsilons = [1e-5, 1e-6]
        for eps in common_epsilons:
            logger.info(
                f"Start register fusion pattern for QKNormRope with epsilons={eps}"
            )
            register_qknorm_rope_fusion(layer.head_size, layer.num_heads,
                                        layer.num_kv_heads, eps)
            register_qknorm_rope_fusion_with_bias(layer.head_size, layer.num_heads,
                                                  layer.num_kv_heads, eps)
        _FUSIONS_REGISTERED = True
    except Exception as e:
        # Log as warning since this might be a significant issue for performance.
        logger.warning(f"Failed to register QKNorm and Rope fusions: {e}")
```

@github-actions
Contributor

github-actions Bot commented Jan 8, 2026

👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:

  • A PR should do only one thing; smaller PRs enable faster reviews.
  • Every PR should include unit tests and end-to-end tests to ensure it works and is not broken by future PRs.
  • Write the commit message by filling in the PR description to help reviewers and future developers understand.

If CI fails, you can run linting and testing checks locally according to Contributing and Testing.

@ForBetterCodeNine ForBetterCodeNine force-pushed the qv_norm_0108 branch 2 times, most recently from 370f3ad to 00b34e9 on January 8, 2026 08:54
"""
name = "AscendCompiler"

def compute_hash(self, vllm_config: VllmConfig) -> str:
Collaborator

No usage of compute_hash() was observed. Why add it?

Collaborator

It's called by vLLM.

```python
from vllm.logger import logger


def _extra_stream_scope_check(match: Match) -> bool:
```
Collaborator

Why not make this function public and reuse it?

```python
extra_check=_extra_stream_scope_check)
```


```python
def get_qknorm_rope_vllm_config() -> VllmConfig:
```
Collaborator

This writing style looks very inappropriate. It is recommended to add a PassManager.
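A minimal sketch of what such a PassManager could look like (plain Python; the flag and pass names are hypothetical, not the PR's API): passes are registered explicitly, and each one checks a config flag at apply time, which also gives every pass a dynamic on/off switch.

```python
# Hypothetical PassManager sketch: explicit registration plus per-pass
# enable flags, instead of module-import-time side effects.
from dataclasses import dataclass, field

@dataclass
class PassManager:
    flags: dict = field(default_factory=dict)
    passes: list = field(default_factory=list)

    def register(self, name, fn):
        self.passes.append((name, fn))

    def run(self, graph):
        for name, fn in self.passes:
            if self.flags.get(name, True):  # default-on, easy to disable per pass
                graph = fn(graph)
        return graph

pm = PassManager(flags={"fuse_qknorm_rope": False})
pm.register("fuse_qknorm_rope", lambda g: g + ["fused"])
print(pm.run([]))  # pass disabled via its flag → []
```

Because registration is an explicit call with the config in hand, there is no hidden dependency on import order, and a misbehaving pass can be switched off without code changes.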



```python
# The replacement registered here will be actually executed after AOT.
def register_qknorm_rope_fusion_with_bias(head_dim, num_heads, num_kv_heads,
```
Collaborator

Passes should support dynamic enabling or disabling; otherwise, some errors will be very difficult to locate.

```python
    return True


@functools.lru_cache(None)
```
Collaborator

remove this

@ForBetterCodeNine ForBetterCodeNine force-pushed the qv_norm_0108 branch 6 times, most recently from 1e43203 to e80bafa on January 13, 2026 03:13
```python
from vllm.logger import logger


def _extra_stream_scope_check(match: Match) -> bool:
```
Collaborator

If the method is called from other files, its name should not start with an "_".

Contributor Author

done

```python
# By default, we enable the graph fusion and quantization fusion pass.
self.ascend_compilation_config: dict = config.additional_config.get(
    "ascend_compilation_config", {})
if self.ascend_compilation_config.get("graphex_norm_quant", True):
```
Collaborator

This does not seem necessary. You can directly use the attributes in the ascend_compilation_config, such as "fuse_norm_quant".

Contributor Author

done

```python
residual,
rms_norm_weight,
# The inverse of scale is required by the npu_add_rms_norm_quant kernel, which is opposite to the npu_quantize kernel.
1. / scale,
```
Collaborator

Please check the current code: 1/scale is no longer used there, and keeping it would introduce redundant operators, such as a reciprocal, into the graph.

Contributor Author

All right. Following your review comments, we have modified the code to use vllm.quantize and replaced the expression 1. / scale with scale directly. After validation on the Qwen3-8B-W8A8 and Qwen3-235B-A22B networks, we confirmed that there are no precision issues and that the redundant operators have been eliminated.
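The scale-convention mismatch discussed above can be shown with toy numbers (plain floats, not the NPU kernels): one kernel-style interface divides by the scale, the other multiplies by its inverse, and passing 1./scale where plain scale is expected forces a redundant reciprocal into the graph.

```python
# Toy illustration of the two quantization-scale conventions.
def quantize_div(x, scale):
    # npu_quantize-style interface: divide by the scale
    return round(x / scale)

def quantize_mul_inv(x, inv_scale):
    # add_rms_norm_quant-style interface: multiply by the inverse scale
    return round(x * inv_scale)

scale = 0.5
# The two agree only when each is handed the form it expects.
assert quantize_div(3.0, scale) == quantize_mul_inv(3.0, 1.0 / scale) == 6
print("conventions agree when the expected form is passed")
```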

@wxsIcey
Collaborator

wxsIcey commented Jan 13, 2026

It is recommended to add an e2e test to catch fusion failures.

@@ -0,0 +1,282 @@
#

Collaborator

Is the file name wrong? It should be add_qknorm_rope_fusion.py.

Contributor Author

done

@wxsIcey
Collaborator

wxsIcey commented Jan 15, 2026

LGTM. Can an e2e test be added to ensure the fusion is successful? like https://github.com/vllm-project/vllm-ascend/blob/main/tests/e2e/singlecard/compile/test_norm_quant_fusion.py

@ForBetterCodeNine
Contributor Author

> LGTM. Can an e2e test be added to ensure the fusion is successful? like https://github.com/vllm-project/vllm-ascend/blob/main/tests/e2e/singlecard/compile/test_norm_quant_fusion.py

OK. When the npu_graph_ex switch is enabled, we will promptly supplement the end-to-end (e2e) test cases to verify its availability.

@wangxiyuan wangxiyuan added the ready (read for review) and ready-for-test (start test by label for PR) labels Jan 15, 2026
@wangxiyuan wangxiyuan enabled auto-merge (squash) January 15, 2026 12:05
@ForBetterCodeNine ForBetterCodeNine force-pushed the qv_norm_0108 branch 2 times, most recently from 04c0afd to 99e5c64 on January 16, 2026 19:11
```python
def configure(self, config: VllmConfig):
    # By default, we enable the graph fusion and quantization fusion pass.
    self.ascend_compilation_config: dict = config.additional_config.get("ascend_compilation_config", {})
    # if self.ascend_compilation_config.get("fuse_norm_quant", True):
```
Collaborator

Check these commented-out lines.

@ForBetterCodeNine ForBetterCodeNine force-pushed the qv_norm_0108 branch 13 times, most recently from 6065042 to 1ace80e on January 21, 2026 01:02
Signed-off-by: cjian <2318164299@qq.com>
1
Signed-off-by: cjian <2318164299@qq.com>
@yiz-liu yiz-liu merged commit 1402cf6 into vllm-project:main Jan 22, 2026
20 checks passed
845473182 pushed a commit to 845473182/vllm-ascend that referenced this pull request Jan 22, 2026
@ForBetterCodeNine ForBetterCodeNine deleted the qv_norm_0108 branch January 22, 2026 13:47
starmountain1997 pushed a commit to starmountain1997/vllm-ascend that referenced this pull request Jan 31, 2026
ZRJ026 pushed a commit to ZRJ026/vllm-ascend that referenced this pull request Feb 28, 2026
maoxx241 pushed a commit to maoxx241/vllm-ascend that referenced this pull request Mar 2, 2026
ZRJ026 pushed a commit to ZRJ026/vllm-ascend that referenced this pull request Mar 4, 2026
LCAIZJ pushed a commit to LCAIZJ/vllm-ascend that referenced this pull request Mar 7, 2026
Labels: module:tests, ready, ready-for-test