[Graph][Fusion] Add QKVNormRope and QKVNormRopeWithBias #5721

yiz-liu merged 2 commits into vllm-project:main from
Conversation
Code Review
This pull request introduces graph fusion optimizations for QKVNormRope and QKVNormRopeWithBias. While the fusion logic itself seems correct, the mechanism for registering these fusions is fragile. It relies on a global compiler instance to pass configuration at module import time, which is an anti-pattern that can lead to silent failures and makes the code difficult to maintain. My review provides critical feedback on this design and suggests a refactoring to an explicit registration pattern, which will improve the robustness and clarity of the code.
```python
# Global reference to the current compiler instance
_current_compiler = None


def set_current_compiler(compiler: "AscendCompiler") -> None:
    """Set the current compiler instance globally."""
    global _current_compiler
    _current_compiler = compiler


def get_current_compiler() -> Optional["AscendCompiler"]:
    """Get the current compiler instance globally."""
    return _current_compiler
```
The introduction of a global _current_compiler to pass configuration to other modules is a fragile design choice. It creates hidden dependencies on execution order (e.g., AscendCompiler.compute_hash must be called before fusion modules are imported) and makes the code harder to reason about and test. A more robust approach would be to pass configuration explicitly, avoiding global state. I've added a concrete refactoring suggestion in vllm_ascend/compilation/npugraph_ex_passes/add_qvnorm_rope_fusion.py that would resolve this by triggering registration explicitly from npugraph_ex_compile.
```python
def get_qknorm_rope_vllm_config() -> VllmConfig:
    """
    Get the vllm_config from the current compiler instance.
    This function is called at module load time to lazily register fusions.

    Returns:
        VllmConfig object from the compiler, or a default one if not available.
    """
    try:
        from vllm_ascend.compilation.compiler_interface import get_current_compiler

        compiler = get_current_compiler()
        if compiler and hasattr(compiler, 'vllm_config'):
            return compiler.vllm_config
    except (ImportError, AttributeError):
        pass

    # Fallback to default config
    return VllmConfig()


# Lazy initialization: try to register with compiler's vllm_config if available
try:
    vllm_config = get_qknorm_rope_vllm_config()
    attn_layers: dict[str, Attention] = get_layers_from_vllm_config(
        vllm_config, Attention)

    if len(attn_layers) == 0:
        logger.debug(
            "QKNorm and Rope fusion enabled, but no Attention layers were discovered."
        )

    layer = next(iter(attn_layers.values()))
    if layer.head_size != 128:
        logger.debug(
            "QKNorm and Rope fusion not enabled: head_dim %d is not equal to 128",
            layer.head_size)

    # register converter for pass
    common_epsilons = [1e-5, 1e-6]
    for eps in common_epsilons:
        logger.info(
            f"Start register fusion pattern for QKNormRope with epsilons={eps}"
        )
        register_qknorm_rope_fusion(layer.head_size, layer.num_heads,
                                    layer.num_kv_heads, eps)
        register_qknorm_rope_fusion_with_bias(layer.head_size, layer.num_heads,
                                              layer.num_kv_heads, eps)
except Exception as e:
    logger.debug(
        f"Failed to register QKNorm and Rope fusions at module load time: {e}. "
        f"This is expected if the compiler hasn't been initialized yet."
    )
```
This module-level fusion registration logic is fragile and has several issues:
- Dependency on global state: it relies on a global compiler instance being set before import. If it is not, the code silently falls back to a default `VllmConfig`, which can lead to fusions being skipped or misconfigured without a clear error.
- Error masking: the broad `except Exception` hides underlying problems during registration, logging them only at the `debug` level. This can make debugging difficult.
- Potential crash: if `attn_layers` is empty, `next(iter(attn_layers.values()))` on line 249 will raise a `StopIteration` exception. This is currently masked by the broad exception handler, but should be handled gracefully.

To make this more robust, I recommend replacing this block with an explicit registration function that is called from npugraph_ex_compile with the vllm_config. This removes the need for global state and makes dependencies clear.
```python
_FUSIONS_REGISTERED = False


def register_qkvnorm_rope_fusions(vllm_config: VllmConfig):
    """
    Registers QKNorm+Rope fusion patterns with torchair.
    This function is designed to be called explicitly once the vllm_config is available.
    """
    global _FUSIONS_REGISTERED
    if _FUSIONS_REGISTERED:
        return
    try:
        attn_layers: dict[str, Attention] = get_layers_from_vllm_config(
            vllm_config, Attention)
        if not attn_layers:
            logger.debug(
                "QKNorm and Rope fusion enabled, but no Attention layers were discovered."
            )
            return
        layer = next(iter(attn_layers.values()))
        if layer.head_size != 128:
            logger.debug(
                "QKNorm and Rope fusion not enabled: head_dim %d is not equal to 128",
                layer.head_size)
            return
        # register converter for pass
        common_epsilons = [1e-5, 1e-6]
        for eps in common_epsilons:
            logger.info(
                f"Start register fusion pattern for QKNormRope with epsilons={eps}"
            )
            register_qknorm_rope_fusion(layer.head_size, layer.num_heads,
                                        layer.num_kv_heads, eps)
            register_qknorm_rope_fusion_with_bias(layer.head_size, layer.num_heads,
                                                  layer.num_kv_heads, eps)
        _FUSIONS_REGISTERED = True
    except Exception as e:
        # Log as warning since this might be a significant issue for performance.
        logger.warning(f"Failed to register QKNorm and Rope fusions: {e}")
```
👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge: if CI fails, you can run the linting and testing checks locally according to Contributing and Testing.
```python
    """
    name = "AscendCompiler"

    def compute_hash(self, vllm_config: VllmConfig) -> str:
```
No usage of compute_hash() was observed. Why add it?
```python
from vllm.logger import logger


def _extra_stream_scope_check(match: Match) -> bool:
```
Why not make this function public and reuse it?
```python
    extra_check=_extra_stream_scope_check)


def get_qknorm_rope_vllm_config() -> VllmConfig:
```
This coding style is not appropriate here. I recommend adding a PassManager instead.
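To illustrate the suggestion, here is a minimal sketch of what such a PassManager could look like; the class and method names are assumptions for illustration, not an existing vllm-ascend API. Passes are registered explicitly and toggled per run, so nothing depends on module import side effects.

```python
from typing import Callable, Dict


class PassManager:
    """Explicit registry of graph-rewrite passes with per-pass toggles."""

    def __init__(self) -> None:
        self._passes: Dict[str, Callable] = {}
        self._enabled: Dict[str, bool] = {}

    def register(self, name: str, fn: Callable, enabled: bool = True) -> None:
        self._passes[name] = fn
        self._enabled[name] = enabled

    def set_enabled(self, name: str, enabled: bool) -> None:
        self._enabled[name] = enabled

    def run(self, graph):
        """Apply every enabled pass to the graph, in registration order."""
        for name, fn in self._passes.items():
            if self._enabled.get(name, False):
                graph = fn(graph)
        return graph


# Toy passes: each one just tags the "graph" (a list) with its name.
pm = PassManager()
pm.register("qknorm_rope", lambda g: g + ["qknorm_rope"])
pm.register("norm_quant", lambda g: g + ["norm_quant"])
pm.set_enabled("norm_quant", False)   # disable one pass without touching code
print(pm.run([]))  # -> ['qknorm_rope']
```

In the real pass manager, `fn` would be an FX graph transformation and `run` would be invoked from the compile path with the traced graph.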
```python
# The replacement registered here will be actually executed after AOT.
def register_qknorm_rope_fusion_with_bias(head_dim, num_heads, num_kv_heads,
```
The pass should support dynamic enabling and disabling; otherwise, some errors will be very difficult to locate.
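One way to support this, sketched below with assumed key and environment-variable names (not the project's actual conventions), is to gate each pass behind a config flag with an environment-variable override, so a suspect fusion can be switched off quickly when bisecting a numerical or graph error:

```python
import os


def fusion_pass_enabled(additional_config: dict, name: str) -> bool:
    """Resolve a per-pass switch: env var overrides config, default is on.

    The `VLLM_ASCEND_DISABLE_<NAME>` env var and the config key layout are
    assumptions for illustration only.
    """
    env_key = f"VLLM_ASCEND_DISABLE_{name.upper()}"
    if os.environ.get(env_key) == "1":
        return False
    compilation_cfg = additional_config.get("ascend_compilation_config", {})
    return bool(compilation_cfg.get(name, True))


cfg = {"ascend_compilation_config": {"fuse_qknorm_rope": True}}
print(fusion_pass_enabled(cfg, "fuse_qknorm_rope"))   # -> True
os.environ["VLLM_ASCEND_DISABLE_FUSE_QKNORM_ROPE"] = "1"
print(fusion_pass_enabled(cfg, "fuse_qknorm_rope"))   # -> False
```

With a switch like this, a user hitting a suspected fusion bug can disable one pass at a time without rebuilding, which makes the failing pattern much easier to isolate.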
```python
    return True


@functools.lru_cache(None)
```
```python
from vllm.logger import logger


def _extra_stream_scope_check(match: Match) -> bool:
```
If the method is called from other files, its name should not start with an underscore (`_`).
```python
# By default, we enable the graph fusion and quantization fusion pass.
self.ascend_compilation_config: dict = config.additional_config.get(
    "ascend_compilation_config", {})
if self.ascend_compilation_config.get("graphex_norm_quant", True):
```
This does not seem necessary. You can directly use the attributes in the ascend_compilation_config, such as "fuse_norm_quant"
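As a sketch of this suggestion, the flags could be read once into plain attributes; the key names `fuse_norm_quant` and `fuse_qknorm_rope` are assumptions here, following the reviewer's example rather than the merged API:

```python
class AscendCompilationOptions:
    """Sketch: read per-pass flags directly from additional_config.

    The flag names are illustrative; defaults keep every pass enabled.
    """

    def __init__(self, additional_config: dict) -> None:
        cfg = additional_config.get("ascend_compilation_config", {})
        # Each pass gets its own attribute; absent keys default to enabled.
        self.fuse_norm_quant: bool = cfg.get("fuse_norm_quant", True)
        self.fuse_qknorm_rope: bool = cfg.get("fuse_qknorm_rope", True)


opts = AscendCompilationOptions(
    {"ascend_compilation_config": {"fuse_norm_quant": False}})
print(opts.fuse_norm_quant, opts.fuse_qknorm_rope)  # -> False True
```

Reading the flags into attributes at configure time keeps the rest of the code free of nested dictionary lookups.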
```python
    residual,
    rms_norm_weight,
    # The inverse of scale is required by npu_add_rms_norm_quant kernel
    # which is opposite to the npu_quantize kernel.
    1. / scale,
```
Please check the current code: `1 / scale` is no longer used there, and keeping it would introduce two redundant operators, such as `reciprocal`, into the graph.
All right. Following your review comments, we have modified the code to use vllm.quantize and replaced the expression `1. / scale` here with `scale` directly. After validation on the Qwen3-8B-W8A8 and Qwen3-235B-A22B networks, we confirmed that there are no precision issues and that the redundant operators have been eliminated.
It is recommended to add an e2e test to catch fusion failures.
```
@@ -0,0 +1,282 @@
#
```
Is the file name wrong? It should be add_qknorm_rope_fusion.py.
LGTM. Can an e2e test be added to ensure the fusion is successful? For example: https://github.com/vllm-project/vllm-ascend/blob/main/tests/e2e/singlecard/compile/test_norm_quant_fusion.py
OK. Once the npu_graph_ex switch is enabled, we will add end-to-end (e2e) test cases promptly to verify that the fusion works.
```python
def configure(self, config: VllmConfig):
    # By default, we enable the graph fusion and quantization fusion pass.
    self.ascend_compilation_config: dict = config.additional_config.get("ascend_compilation_config", {})
    # if self.ascend_compilation_config.get("fuse_norm_quant", True):
```
Please clean up these commented-out lines.
…to qwen3next_rebase

* 'main' of https://github.com/vllm-project/vllm-ascend: (51 commits)
  - [Bugfix] Remove `use_aclgraph` in mtp_proposer and use `use_cuda_graph` (vllm-project#6032)
  - [BugFix] fix 3vl dense model load quant weight (vllm-project#6100)
  - [CP&SP] Integrate FIA operator in mla_cp._forward_decode (vllm-project#5641)
  - [CI][Doc] Upgrade wheel building's CANN to 8.5.0 and update the Docs (vllm-project#6145)
  - [CI] Install clang in dockerfile for triton ascend (vllm-project#4409)
  - [Main] Upgrade PTA to 2.9.0 (vllm-project#6112)
  - [Graph][Fusion] Add QKVNormRope and QKVNormRopeWithBias (vllm-project#5721)
  - [P/D][PCP] bugfix pcp force free twice caused logger error (vllm-project#6124)
  - [BugFix] converting pa get_workspace back to capturing (vllm-project#5833)
  - [CI] optimize lint term (vllm-project#5986)
  - [Bugfix] Fix Triton operator usage for multimodal models based on the `mrope_interleaved` parameter (vllm-project#6042)
  - [bugfix][npugraph_ex] fix the model output type issue caused by manually modifying the FX graph (vllm-project#6015)
  - [BugFix] Support setting tp=1 for the Eagle draft model to take effect (vllm-project#6097)
  - [Misc] Bump mooncake version to v0.3.8.post1 (vllm-project#6110)
  - [Feature] Enable DispatchGmmCombineDecode when eagle is moe with w8a8 or not moe [RFC: issue 5476] (vllm-project#5758)
  - [bugfix] adapt_remote_request_id (vllm-project#6051)
  - [Feature] Add support of new W4A4_LAOS_DYNAMIC quantization method (vllm-project#5143)
  - [Feature] Support DSA-CP for Hybrid scenario (vllm-project#5702)
  - [CI] Upgrade CANN to 8.5.0 (vllm-project#6070)
  - Default enable MLAPO (vllm-project#5952)
  - ...
…#5721)

### What this PR does / why we need it?
This PR builds upon PR vllm-project#5011 and aims to further enhance the npu_graph_ex_passes module. Based on prior work, we have added graph optimization support for the add_rms_quant fused operator in scenarios where a bias term is present, ensuring the fusion pattern is correctly registered and matched into the computation graph.

For validation, we switched to the Qwen3-235B-A22B-W8A8 model for QKVNormRopeWithBias and the Qwen3-32B model for QKVNormRope. Benchmark results show that, compared to the unfused baseline, enabling this fusion pass significantly improves inference throughput for W8A8 quantized models.

For more details, refer to the RFC: vllm-project#4715

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
```python
llm = LLM(
    model=model,
    tensor_parallel_size=GPUs_per_dp_rank,
    enforce_eager=False,
    enable_expert_parallel=enable_expert_parallel,
    trust_remote_code=trust_remote_code,
    gpu_memory_utilization=0.98,
    max_num_batched_tokens=512,
    # load_format="dummy",
    max_model_len=2048,
    max_num_seqs=16,
    quantization="ascend",
    additional_config={
        "refresh": True,
        "enable_npugraph_ex": True
    },
    compilation_config={
        "cudagraph_capture_sizes": [8, 16],
        "cudagraph_mode": "FULL_DECODE_ONLY",
    },
)

if profile_dir:
    llm.start_profile()
outputs = llm.generate(prompts, sampling_params)
if profile_dir:
    llm.stop_profile()

for i, output in enumerate(outputs):
    if i >= 5:
        break
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(
        f"DP rank {global_dp_rank}, Prompt: {prompt!r}, "
        f"Generated text: {generated_text!r}"
    )
```

- vLLM version: v0.13.0
- vLLM main: vllm-project/vllm@2f4e654

Signed-off-by: cjian <2318164299@qq.com>