Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,10 @@
from vllm.distributed import get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import get_tp_group

from vllm_ascend.compilation.npugraph_ex_passes.utils.npugraph_ex_utils_check import extra_stream_scope_check
from vllm_ascend.compilation.npugraph_ex_passes.utils.npugraph_ex_utils_check import (
check_and_register_fusion_pass,
extra_stream_scope_check,
)

# computation-communication tiling block is 512
ALLREDUCE_NORM_FUSE_THREHOLD = 512
Expand Down Expand Up @@ -136,8 +139,8 @@ def replacement(x, weight, residual, rms_norm_weight):

class GraphEXMatmulAllReduceAddRMSNormPass:
def __init__(self, vllm_config: VllmConfig):
GraphEXMiddleLayerMatmulAllReduceAddRMSNormPattern(vllm_config).register()
GraphEXLastLayerMatmulAllReduceAddRMSNormPattern(vllm_config).register()
check_and_register_fusion_pass(GraphEXMiddleLayerMatmulAllReduceAddRMSNormPattern, vllm_config=vllm_config)
check_and_register_fusion_pass(GraphEXLastLayerMatmulAllReduceAddRMSNormPattern, vllm_config=vllm_config)

def __call__(self, graph: torch.fx.Graph):
pass
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@
from vllm.config.compilation import Range
from vllm.logger import logger

from vllm_ascend.compilation.npugraph_ex_passes.utils.npugraph_ex_utils_check import extra_stream_scope_check
from vllm_ascend.compilation.npugraph_ex_passes.utils.npugraph_ex_utils_check import (
check_and_register_fusion_pass,
extra_stream_scope_check,
)


class GraphEXAddRMSNormQuantPattern:
Expand Down Expand Up @@ -301,10 +304,10 @@ def __init__(self, vllm_config: VllmConfig):

common_epsilons = [1e-5, 1e-6]
for eps in common_epsilons:
GraphEXAddRMSNormQuantPattern(vllm_config, eps=eps).register()
GraphEXAddRMSNormQuantPatternWithBias(vllm_config, eps=eps).register()
GraphEXAddRMSNormQuantSPPattern(vllm_config, eps=eps).register()
GraphEXAddRMSNormQuantSPPatternWithBias(vllm_config, eps=eps).register()
check_and_register_fusion_pass(GraphEXAddRMSNormQuantPattern, vllm_config=vllm_config, eps=eps)
check_and_register_fusion_pass(GraphEXAddRMSNormQuantPatternWithBias, vllm_config=vllm_config, eps=eps)
check_and_register_fusion_pass(GraphEXAddRMSNormQuantSPPattern, vllm_config=vllm_config, eps=eps)
check_and_register_fusion_pass(GraphEXAddRMSNormQuantSPPatternWithBias, vllm_config=vllm_config, eps=eps)

def __call__(self, graph: torch.fx.Graph):
pass
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@
from vllm.config.compilation import Range
from vllm.logger import logger

from vllm_ascend.compilation.npugraph_ex_passes.utils.npugraph_ex_utils_check import extra_stream_scope_check
from vllm_ascend.compilation.npugraph_ex_passes.utils.npugraph_ex_utils_check import (
check_and_register_fusion_pass,
extra_stream_scope_check,
)


class GraphEXQKNormRopeFusionPattern:
Expand Down Expand Up @@ -202,20 +205,22 @@ def __init__(self, vllm_config: VllmConfig):
if layer.head_size != 128:
logger.debug("QKNorm and Rope fusion not enabled: head_dim %d is not equal of 128", layer.head_size)
continue
GraphEXQKNormRopeFusionPattern(
check_and_register_fusion_pass(
GraphEXQKNormRopeFusionPattern,
vllm_config=vllm_config,
head_dim=layer.head_size,
num_heads=layer.num_heads,
num_kv_heads=layer.num_kv_heads,
eps=epsilon,
).register()
GraphEXQKNormRopeFusionPatternWithBias(
)
check_and_register_fusion_pass(
GraphEXQKNormRopeFusionPatternWithBias,
vllm_config=vllm_config,
head_dim=layer.head_size,
num_heads=layer.num_heads,
num_kv_heads=layer.num_kv_heads,
eps=epsilon,
).register()
)

def __call__(self, graph: torch.fx.Graph):
pass
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,25 @@ def extra_stream_scope_check(match: Match) -> bool:
return False

return True


_register_patterns = set()


def check_and_register_fusion_pass(pattern_class: type, **kwargs):
global _register_patterns
eps = kwargs.get("eps", 1e-6)
pattern_key = str(pattern_class.__name__) + str(eps)
if pattern_key in _register_patterns:
return

pattern = pattern_class(**kwargs)
try:
pattern.register()
_register_patterns.add(pattern_key)
except RuntimeError as e:
if "Duplicate pattern" in str(e):
logger.warning(f"Pattern {pattern_class.__name__} eps {eps} has been registered")
_register_patterns.add(pattern_key)
else:
raise e
Comment on lines +59 to +75
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

There's a typo in the global variable name _resgister_patterns on lines 61 and 64. It should be _register_patterns. This typo causes the check for already registered patterns to always fail because it refers to a different, empty set. This defeats the purpose of this function and will not prevent duplicate registration errors. This is a critical bug.

Suggested change
def check_and_register_fusion_pass(pattern_class: type, **kwargs):
global _resgister_patterns
eps = kwargs.get("eps", 1e-6)
pattern_key = str(pattern_class.__name__) + str(eps)
if pattern_key in _resgister_patterns:
return
pattern = pattern_class(**kwargs)
try:
pattern.register()
_register_patterns.add(pattern_key)
except RuntimeError as e:
if "Duplicate pattern" in str(e):
logger.warning(f"Pattern {pattern_class.__name__} eps {eps} has been registered")
_register_patterns.add(pattern_key)
else:
raise e
def check_and_register_fusion_pass(pattern_class: type, **kwargs):
global _register_patterns
eps = kwargs.get("eps", 1e-6)
pattern_key = str(pattern_class.__name__) + str(eps)
if pattern_key in _register_patterns:
return
pattern = pattern_class(**kwargs)
try:
pattern.register()
_register_patterns.add(pattern_key)
except RuntimeError as e:
if "Duplicate pattern" in str(e):
logger.warning(f"Pattern {pattern_class.__name__} eps {eps} has been registered")
_register_patterns.add(pattern_key)
else:
raise e