Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
198 changes: 198 additions & 0 deletions vllm_ascend/compilation/passes/norm_quant_fusion_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,200 @@ def replacement(
pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)


class AddRMSNormDynamicQuantPattern:
    """Inductor pattern that fuses ``npu_add_rms_norm`` followed by
    ``npu_dynamic_quant`` into the single fused op
    ``npu_add_rms_norm_dynamic_quant``.

    One instance is registered per epsilon value, since ``eps`` is baked
    into the traced pattern as a constant.
    """

    def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
        self.vllm_config = vllm_config
        # Trace example inputs in the model's dtype so the pattern matches
        # the dtypes appearing in the compiled graph.
        self.dtype = vllm_config.model_config.dtype
        self.eps = eps

    def get_inputs(self):
        """
        Generate example inputs for the AddRMSNormDynamicQuant fusion pattern.

        Shapes (2, 4) are arbitrary small dummies used only for tracing the
        pattern/replacement graphs; the match is shape-agnostic.
        """
        rms_norm_input = torch.randn(2, 4, device="npu", dtype=self.dtype)
        residual = torch.randn(2, 4, device="npu", dtype=self.dtype)
        rms_norm_weight = torch.randn(4, device="npu", dtype=self.dtype)
        return [rms_norm_input, residual, rms_norm_weight]

    def register(self, pm_pass: PatternMatcherPass):
        def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor, rms_norm_weight: torch.Tensor):
            """
            Pattern for AddRMSNormDynamicQuant fusion: add+RMSNorm, then
            per-token dynamic quantization of the normalized output.
            """
            output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, self.eps)
            # output[0]: normalized result; output[2]: updated residual
            # (output[1] is the rstd, unused here).
            out0 = output[0]
            out1 = output[2]
            quantized_output = torch.ops.npu.npu_dynamic_quant(out0)
            # Returns (quantized int8 tensor, per-token scale, residual).
            return quantized_output[0], quantized_output[1], out1

        def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor, rms_norm_weight: torch.Tensor):
            """
            Replacement for the AddRMSNormDynamicQuant fusion.
            """
            output = torch.ops.npu.npu_add_rms_norm_dynamic_quant(
                rms_norm_input, residual, rms_norm_weight, epsilon=self.eps, output_mask=[True, True]
            )
            # NOTE(review): fused op presumably returns
            # (y1, y2, residual, scale1, scale2); the triple below mirrors the
            # pattern's (quantized, scale, residual) — confirm against CANN docs.
            return (
                output[0],
                output[3],
                output[2],
            )

        pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)


class AddRMSNormDynamicQuantPatternWithBias:
    """Inductor pattern that fuses the custom biased add-RMSNorm op
    (``_C_ascend.npu_add_rms_norm_bias``) followed by ``npu_dynamic_quant``
    into the single fused op ``npu_add_rms_norm_dynamic_quant`` (with the
    bias passed as ``beta``).

    Only registered when the custom op library is enabled (see
    ``enable_custom_op()`` in the fusion pass). One instance per epsilon.
    """

    def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
        self.vllm_config = vllm_config
        # Trace example inputs in the model's dtype so the pattern matches
        # the dtypes appearing in the compiled graph.
        self.dtype = vllm_config.model_config.dtype
        self.eps = eps

    def get_inputs(self):
        """
        Generate example inputs for the biased AddRMSNormDynamicQuant fusion
        pattern.

        Shapes (2, 4) are arbitrary small dummies used only for tracing; the
        match is shape-agnostic.
        """
        rms_norm_input = torch.randn(2, 4, device="npu", dtype=self.dtype)
        residual = torch.randn(2, 4, device="npu", dtype=self.dtype)
        rms_norm_weight = torch.randn(4, device="npu", dtype=self.dtype)
        rmsnorm_bias = torch.randn(4, device="npu", dtype=self.dtype)
        return [rms_norm_input, residual, rms_norm_weight, rmsnorm_bias]

    def register(self, pm_pass: PatternMatcherPass):
        def pattern(
            rms_norm_input: torch.Tensor,
            residual: torch.Tensor,
            rms_norm_weight: torch.Tensor,
            bias: torch.Tensor,
        ):
            """
            Pattern for the biased AddRMSNormDynamicQuant fusion: biased
            add+RMSNorm, then per-token dynamic quantization.
            """
            output = torch.ops._C_ascend.npu_add_rms_norm_bias(
                rms_norm_input, residual, rms_norm_weight, bias, self.eps
            )
            # output[0]: normalized result; output[2]: updated residual
            # (output[1] is the rstd, unused here).
            out0 = output[0]
            out1 = output[2]
            quantized_output = torch.ops.npu.npu_dynamic_quant(out0)
            # Returns (quantized int8 tensor, per-token scale, residual).
            return quantized_output[0], quantized_output[1], out1

        def replacement(
            rms_norm_input: torch.Tensor,
            residual: torch.Tensor,
            rms_norm_weight: torch.Tensor,
            bias: torch.Tensor,
        ):
            """
            Replacement for the biased AddRMSNormDynamicQuant fusion; the
            RMSNorm bias is forwarded as ``beta`` to the fused kernel.
            """
            output = torch.ops.npu.npu_add_rms_norm_dynamic_quant(
                rms_norm_input, residual, rms_norm_weight, epsilon=self.eps, output_mask=[True, True], beta=bias
            )
            # NOTE(review): fused op presumably returns
            # (y1, y2, residual, scale1, scale2); the triple below mirrors the
            # pattern's (quantized, scale, residual) — confirm against CANN docs.
            return (
                output[0],
                output[3],
                output[2],
            )

        pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)


class AddRMSNormDynamicQuantSPPattern:
    """Sequence-parallel variant of :class:`AddRMSNormDynamicQuantPattern`.

    Matches add-RMSNorm -> all-gather -> dynamic quant, and rewrites it to
    the fused add-RMSNorm-dynamic-quant op followed by all-gathers of the
    (smaller) quantized output and its scales — i.e. the collective is moved
    after quantization to reduce communication volume.
    """

    def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
        self.vllm_config = vllm_config
        # Trace example inputs in the model's dtype so the pattern matches
        # the dtypes appearing in the compiled graph.
        self.dtype = vllm_config.model_config.dtype
        self.eps = eps

    def get_inputs(self):
        """
        Generate example inputs for the SP AddRMSNormDynamicQuant fusion
        pattern.

        Shapes (2, 4) are arbitrary small dummies used only for tracing; the
        match is shape-agnostic.
        """
        rms_norm_input = torch.randn(2, 4, device="npu", dtype=self.dtype)
        residual = torch.randn(2, 4, device="npu", dtype=self.dtype)
        rms_norm_weight = torch.randn(4, device="npu", dtype=self.dtype)
        return [rms_norm_input, residual, rms_norm_weight]

    def register(self, pm_pass: PatternMatcherPass):
        def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor, rms_norm_weight: torch.Tensor):
            """
            Pattern: add+RMSNorm, all-gather the normalized output across the
            SP group, then dynamically quantize the gathered tensor.
            """
            output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, self.eps)
            # output[0]: normalized result; output[2]: updated residual.
            out0 = output[0]
            out1 = output[2]
            out0 = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(out0, True)
            quantized_output = torch.ops.npu.npu_dynamic_quant(out0)
            return quantized_output[0], quantized_output[1], out1

        def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor, rms_norm_weight: torch.Tensor):
            """
            Replacement: fused add-RMSNorm+dynamic-quant first, then
            all-gather the quantized tensor and its per-token scales so the
            collective moves int8 data instead of full-precision activations.
            """
            output = torch.ops.npu.npu_add_rms_norm_dynamic_quant(
                rms_norm_input, residual, rms_norm_weight, epsilon=self.eps, output_mask=[True, True]
            )
            # NOTE(review): output[3] is presumably the per-token scale and
            # output[2] the updated residual — confirm against CANN docs.
            out3 = output[3]
            quantized_output = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(output[0], True)
            out3 = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(out3, True)
            return quantized_output, out3, output[2]

        pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)


class AddRMSNormDynamicQuantSPPatternWithBias:
    """Sequence-parallel, biased variant of the dynamic-quant fusion.

    Matches the custom biased add-RMSNorm (``_C_ascend.npu_add_rms_norm_bias``)
    -> all-gather -> dynamic quant, and rewrites it to the fused
    add-RMSNorm-dynamic-quant op (bias as ``beta``) followed by all-gathers of
    the quantized output and scales. Only registered when the custom op
    library is enabled.
    """

    def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
        self.vllm_config = vllm_config
        # Trace example inputs in the model's dtype so the pattern matches
        # the dtypes appearing in the compiled graph.
        self.dtype = vllm_config.model_config.dtype
        self.eps = eps

    def get_inputs(self):
        """
        Generate example inputs for the biased SP AddRMSNormDynamicQuant
        fusion pattern.

        Shapes (2, 4) are arbitrary small dummies used only for tracing; the
        match is shape-agnostic.
        """
        rms_norm_input = torch.randn(2, 4, device="npu", dtype=self.dtype)
        residual = torch.randn(2, 4, device="npu", dtype=self.dtype)
        rms_norm_weight = torch.randn(4, device="npu", dtype=self.dtype)
        rmsnorm_bias = torch.randn(4, device="npu", dtype=self.dtype)
        return [rms_norm_input, residual, rms_norm_weight, rmsnorm_bias]

    def register(self, pm_pass: PatternMatcherPass):
        def pattern(
            rms_norm_input: torch.Tensor,
            residual: torch.Tensor,
            rms_norm_weight: torch.Tensor,
            bias: torch.Tensor,
        ):
            """
            Pattern: biased add+RMSNorm, all-gather the normalized output
            across the SP group, then dynamically quantize it.
            """
            output = torch.ops._C_ascend.npu_add_rms_norm_bias(
                rms_norm_input, residual, rms_norm_weight, bias, self.eps
            )
            # output[0]: normalized result; output[2]: updated residual.
            out0 = output[0]
            out1 = output[2]
            out0 = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(out0, True)
            quantized_output = torch.ops.npu.npu_dynamic_quant(out0)
            return quantized_output[0], quantized_output[1], out1

        def replacement(
            rms_norm_input: torch.Tensor,
            residual: torch.Tensor,
            rms_norm_weight: torch.Tensor,
            bias: torch.Tensor,
        ):
            """
            Replacement: fused biased add-RMSNorm+dynamic-quant first, then
            all-gather the quantized tensor and its per-token scales so the
            collective moves int8 data instead of full-precision activations.
            """
            output = torch.ops.npu.npu_add_rms_norm_dynamic_quant(
                rms_norm_input, residual, rms_norm_weight, epsilon=self.eps, output_mask=[True, True], beta=bias
            )
            # NOTE(review): output[3] is presumably the per-token scale and
            # output[2] the updated residual — confirm against CANN docs.
            out3 = output[3]
            quantized_output = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(output[0], True)
            out3 = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(out3, True)
            return quantized_output, out3, output[2]

        pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)
Comment thread
Zhang-Bryan marked this conversation as resolved.


class AddRMSNormQuantFusionPass(VllmInductorPass):
"""
A pass for fusing AddRMSNorm and W8A8 quantization operations on Ascend.
Expand All @@ -286,9 +480,13 @@ def __init__(self, vllm_config: VllmConfig):
for eps in common_epsilons:
AddRMSNormQuantPattern(vllm_config, eps=eps).register(self.pattern_match_passes)
AddRMSNormQuantSPPattern(vllm_config, eps=eps).register(self.pattern_match_passes)
AddRMSNormDynamicQuantPattern(vllm_config, eps=eps).register(self.pattern_match_passes)
AddRMSNormDynamicQuantSPPattern(vllm_config, eps=eps).register(self.pattern_match_passes)
if enable_custom_op():
AddRMSNormQuantPatternWithBias(vllm_config, eps=eps).register(self.pattern_match_passes)
AddRMSNormQuantSPPatternWithBias(vllm_config, eps=eps).register(self.pattern_match_passes)
AddRMSNormDynamicQuantPatternWithBias(vllm_config, eps=eps).register(self.pattern_match_passes)
AddRMSNormDynamicQuantSPPatternWithBias(vllm_config, eps=eps).register(self.pattern_match_passes)

def __call__(self, graph: torch.fx.Graph):
self.begin()
Expand Down