diff --git a/vllm_ascend/compilation/npugraph_ex_passes/graphex_allreduce_rmsnorm_fusion_pass.py b/vllm_ascend/compilation/npugraph_ex_passes/graphex_allreduce_rmsnorm_fusion_pass.py
index 94a08389256..711d804b1b5 100644
--- a/vllm_ascend/compilation/npugraph_ex_passes/graphex_allreduce_rmsnorm_fusion_pass.py
+++ b/vllm_ascend/compilation/npugraph_ex_passes/graphex_allreduce_rmsnorm_fusion_pass.py
@@ -21,7 +21,10 @@
 from vllm.distributed import get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce
 from vllm.distributed.parallel_state import get_tp_group
-from vllm_ascend.compilation.npugraph_ex_passes.utils.npugraph_ex_utils_check import extra_stream_scope_check
+from vllm_ascend.compilation.npugraph_ex_passes.utils.npugraph_ex_utils_check import (
+    check_and_register_fusion_pass,
+    extra_stream_scope_check,
+)
 
 # computation-communication tiling block is 512
 ALLREDUCE_NORM_FUSE_THREHOLD = 512
 
@@ -136,8 +139,8 @@ def replacement(x, weight, residual, rms_norm_weight):
 
 class GraphEXMatmulAllReduceAddRMSNormPass:
     def __init__(self, vllm_config: VllmConfig):
-        GraphEXMiddleLayerMatmulAllReduceAddRMSNormPattern(vllm_config).register()
-        GraphEXLastLayerMatmulAllReduceAddRMSNormPattern(vllm_config).register()
+        check_and_register_fusion_pass(GraphEXMiddleLayerMatmulAllReduceAddRMSNormPattern, vllm_config=vllm_config)
+        check_and_register_fusion_pass(GraphEXLastLayerMatmulAllReduceAddRMSNormPattern, vllm_config=vllm_config)
 
     def __call__(self, graph: torch.fx.Graph):
         pass
diff --git a/vllm_ascend/compilation/npugraph_ex_passes/graphex_norm_quant_fusion_pass.py b/vllm_ascend/compilation/npugraph_ex_passes/graphex_norm_quant_fusion_pass.py
index 5c41100a1cf..1534b038195 100644
--- a/vllm_ascend/compilation/npugraph_ex_passes/graphex_norm_quant_fusion_pass.py
+++ b/vllm_ascend/compilation/npugraph_ex_passes/graphex_norm_quant_fusion_pass.py
@@ -22,7 +22,10 @@
 
 from vllm.config.compilation import Range
 from vllm.logger import logger
-from vllm_ascend.compilation.npugraph_ex_passes.utils.npugraph_ex_utils_check import extra_stream_scope_check
+from vllm_ascend.compilation.npugraph_ex_passes.utils.npugraph_ex_utils_check import (
+    check_and_register_fusion_pass,
+    extra_stream_scope_check,
+)
 
 
 class GraphEXAddRMSNormQuantPattern:
@@ -301,10 +304,10 @@ def __init__(self, vllm_config: VllmConfig):
 
         common_epsilons = [1e-5, 1e-6]
         for eps in common_epsilons:
-            GraphEXAddRMSNormQuantPattern(vllm_config, eps=eps).register()
-            GraphEXAddRMSNormQuantPatternWithBias(vllm_config, eps=eps).register()
-            GraphEXAddRMSNormQuantSPPattern(vllm_config, eps=eps).register()
-            GraphEXAddRMSNormQuantSPPatternWithBias(vllm_config, eps=eps).register()
+            check_and_register_fusion_pass(GraphEXAddRMSNormQuantPattern, vllm_config=vllm_config, eps=eps)
+            check_and_register_fusion_pass(GraphEXAddRMSNormQuantPatternWithBias, vllm_config=vllm_config, eps=eps)
+            check_and_register_fusion_pass(GraphEXAddRMSNormQuantSPPattern, vllm_config=vllm_config, eps=eps)
+            check_and_register_fusion_pass(GraphEXAddRMSNormQuantSPPatternWithBias, vllm_config=vllm_config, eps=eps)
 
     def __call__(self, graph: torch.fx.Graph):
         pass
diff --git a/vllm_ascend/compilation/npugraph_ex_passes/graphex_qknorm_rope_fusion_pass.py b/vllm_ascend/compilation/npugraph_ex_passes/graphex_qknorm_rope_fusion_pass.py
index 3317d132c58..8586e6d9d18 100644
--- a/vllm_ascend/compilation/npugraph_ex_passes/graphex_qknorm_rope_fusion_pass.py
+++ b/vllm_ascend/compilation/npugraph_ex_passes/graphex_qknorm_rope_fusion_pass.py
@@ -23,7 +23,10 @@
 
 from vllm.config.compilation import Range
 from vllm.logger import logger
-from vllm_ascend.compilation.npugraph_ex_passes.utils.npugraph_ex_utils_check import extra_stream_scope_check
+from vllm_ascend.compilation.npugraph_ex_passes.utils.npugraph_ex_utils_check import (
+    check_and_register_fusion_pass,
+    extra_stream_scope_check,
+)
 
 
 class GraphEXQKNormRopeFusionPattern:
@@ -202,20 +205,22 @@ def __init__(self, vllm_config: VllmConfig):
             if layer.head_size != 128:
                 logger.debug("QKNorm and Rope fusion not enabled: head_dim %d is not equal of 128", layer.head_size)
                 continue
-            GraphEXQKNormRopeFusionPattern(
+            check_and_register_fusion_pass(
+                GraphEXQKNormRopeFusionPattern,
                 vllm_config=vllm_config,
                 head_dim=layer.head_size,
                 num_heads=layer.num_heads,
                 num_kv_heads=layer.num_kv_heads,
                 eps=epsilon,
-            ).register()
-            GraphEXQKNormRopeFusionPatternWithBias(
+            )
+            check_and_register_fusion_pass(
+                GraphEXQKNormRopeFusionPatternWithBias,
                 vllm_config=vllm_config,
                 head_dim=layer.head_size,
                 num_heads=layer.num_heads,
                 num_kv_heads=layer.num_kv_heads,
                 eps=epsilon,
-            ).register()
+            )
 
     def __call__(self, graph: torch.fx.Graph):
         pass
diff --git a/vllm_ascend/compilation/npugraph_ex_passes/utils/npugraph_ex_utils_check.py b/vllm_ascend/compilation/npugraph_ex_passes/utils/npugraph_ex_utils_check.py
index 481a16ed8c9..a81dbcc9324 100644
--- a/vllm_ascend/compilation/npugraph_ex_passes/utils/npugraph_ex_utils_check.py
+++ b/vllm_ascend/compilation/npugraph_ex_passes/utils/npugraph_ex_utils_check.py
@@ -51,3 +51,36 @@ def extra_stream_scope_check(match: Match) -> bool:
             return False
 
     return True
+
+
+# Patterns already registered in this process, keyed by pattern class name
+# plus every construction kwarg except vllm_config.
+_registered_patterns: set[str] = set()
+
+
+def check_and_register_fusion_pass(pattern_class: type, **kwargs) -> None:
+    """Instantiate *pattern_class* with ``kwargs`` and register it exactly once.
+
+    Compilation may run these pass constructors more than once; registering
+    the same pattern twice raises a "Duplicate pattern" RuntimeError from the
+    pattern matcher. This helper deduplicates locally and tolerates the
+    duplicate error when another path registered the pattern first.
+    """
+    # Key on every construction argument except vllm_config so that the same
+    # class registered with different shapes (head_dim, num_heads, eps, ...)
+    # counts as a distinct pattern. Keying only on eps would silently drop
+    # per-layer variants (e.g. QKNorm patterns built per attention layer).
+    extra_args = sorted((name, repr(value)) for name, value in kwargs.items() if name != "vllm_config")
+    pattern_key = f"{pattern_class.__name__}:{extra_args}"
+    if pattern_key in _registered_patterns:
+        return
+
+    pattern = pattern_class(**kwargs)
+    try:
+        pattern.register()
+    except RuntimeError as err:
+        if "Duplicate pattern" not in str(err):
+            raise
+        # Registered elsewhere already; record it so we skip the retry.
+        # NOTE(review): confirm `logger` is imported at the top of this module.
+        logger.warning("Pattern %s (args %s) has already been registered", pattern_class.__name__, extra_args)
+    _registered_patterns.add(pattern_key)