From d85bb578b39701a59e6716ad9389b8823fd31c4e Mon Sep 17 00:00:00 2001 From: cascade812 Date: Mon, 2 Jun 2025 04:38:16 +0000 Subject: [PATCH 1/6] add sp for fused rmsnorm with quantize op Signed-off-by: cascade812 --- vllm/compilation/fusion.py | 4 +- vllm/compilation/sequence_parallelism.py | 414 +++++++++++++++++++++++ 2 files changed, 416 insertions(+), 2 deletions(-) diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index 618b2fe94d3a..4092bfdbeb61 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -313,8 +313,8 @@ def process(self): # 0 is always None fused_return_mapping = {1: (quant_node, 1), 2: (rms_node, 2)} self.insert_fused_node(fused_return_mapping, - epsilon=rms_node.kwargs["epsilon"], - **kwargs) + **kwargs, + epsilon=rms_node.kwargs["epsilon"]) class RMSNormDynamicQuantPattern(RMSNormQuantPattern): diff --git a/vllm/compilation/sequence_parallelism.py b/vllm/compilation/sequence_parallelism.py index 17dded87fe8d..4f82bb3a88e4 100644 --- a/vllm/compilation/sequence_parallelism.py +++ b/vllm/compilation/sequence_parallelism.py @@ -11,6 +11,7 @@ from vllm.distributed.parallel_state import ( get_tensor_model_parallel_world_size) from vllm.logger import init_logger +from vllm.platforms import current_platform from .vllm_inductor_pass import VllmInductorPass @@ -234,6 +235,402 @@ def replacement( pm.fwd_only, pm_pass) +FP8_DTYPE = current_platform.fp8_dtype() + + +class EmbeddingAllReduceRMSNormStaticFP8Pattern(AllReduceRMSNormPattern): + + def get_inputs(self): + arg2_1 = torch.empty([16, 4], device=self.device, dtype=self.dtype) + mul_6 = torch.tensor([[3, 7, 1, 4, 9, 2, 5, 0]], + device=self.device, + dtype=torch.long) + unsqueeze = torch.rand([1, 8, 1], device=self.device, \ + dtype=self.dtype) > 0.5 + full_default = torch.zeros([1, 8, 4], device=self.device, \ + dtype=self.dtype) + result = torch.empty([1, 8, 4], device=self.device, dtype=FP8_DTYPE) + weight = torch.empty([4], device=self.device, dtype=self.dtype) + scale = torch.tensor(1.0, device=self.device, dtype=torch.float32) + return [arg2_1, mul_6, unsqueeze, full_default, result, weight, scale] + + def register(self, pm_pass: PatternMatcherPass): + + def pattern( + arg2_1: torch.Tensor, + mul_6: torch.Tensor, + unsqueeze: torch.Tensor, + full_default: torch.Tensor, + result: torch.Tensor, + weight: torch.Tensor, + scale: torch.Tensor, + ): + embedding = torch.ops.aten.embedding.default(arg2_1, mul_6) + where = torch.ops.aten.where.self(unsqueeze, full_default, + embedding) + all_reduce = tensor_model_parallel_all_reduce(where) + rmsnorm = torch.ops.higher_order.auto_functionalized( + torch.ops._C.rms_norm_static_fp8_quant.default, + result=result, + input=all_reduce, + weight=weight, + scale=scale, + epsilon=self.epsilon, + ) + + return rmsnorm[1], all_reduce + + def replacement( + arg2_1: torch.Tensor, + mul_6: torch.Tensor, + unsqueeze: torch.Tensor, + full_default: torch.Tensor, + result: torch.Tensor, + weight: torch.Tensor, + scale: torch.Tensor, + ): + embedding = torch.ops.aten.embedding.default(arg2_1, mul_6) + where = torch.ops.aten.where.self(unsqueeze, full_default, + embedding) + + tp = get_tp_group() + tp_size = get_tensor_model_parallel_world_size() + reduce_scatter = torch.ops.vllm.reduce_scatter.default( + where, dim=0, world_size=tp_size, group_name=tp.unique_name) + + rmsnorm_result = torch.empty_like(reduce_scatter, + dtype=result.dtype) + rmsnorm = torch.ops.higher_order.auto_functionalized( + torch.ops._C.rms_norm_static_fp8_quant.default, + 
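+                # the result buffer below is allocated in the quantized (FP8)
+                # dtype; auto_functionalized returns the mutated buffer as
+                # element 1 of its output tuple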
result=rmsnorm_result, + input=reduce_scatter, + weight=weight, + scale=scale, + epsilon=self.epsilon, + ) + + all_gather = torch.ops.vllm.all_gather.default( + rmsnorm[1], + dim=0, + world_size=tp_size, + group_name=tp.unique_name) + + return all_gather, reduce_scatter + + pm.register_replacement(pattern, replacement, self.get_inputs(), + pm.fwd_only, pm_pass) + + +class MiddleAllReduceRMSNormStaticFP8Pattern(AllReduceRMSNormPattern): + + def get_inputs(self): + mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype) + + residual = torch.empty([4, 4], device=self.device, dtype=self.dtype) + rms_norm_weights = torch.empty([4, 4], + device=self.device, + dtype=self.dtype) + result = torch.empty([4, 4], device=self.device, dtype=FP8_DTYPE) + scale = torch.empty([1, 1], device=self.device, dtype=torch.float32) + + return [ + result, + residual, + mm_1, + rms_norm_weights, + scale, + ] + + def register(self, pm_pass: PatternMatcherPass): + + def pattern( + result: torch.Tensor, + residual: torch.Tensor, + mm_1: torch.Tensor, + rms_norm_weights: torch.Tensor, + scale: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + all_reduce = tensor_model_parallel_all_reduce(mm_1) + + rmsnorm = torch.ops.higher_order.auto_functionalized( + torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, + result=result, + input=all_reduce, + residual=residual, + weight=rms_norm_weights, + scale=scale, + epsilon=self.epsilon, + ) + + return rmsnorm[1], rmsnorm[2] + + def replacement( + result: torch.Tensor, + residual: torch.Tensor, + mm_1: torch.Tensor, + rms_norm_weights: torch.Tensor, + scale: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + tp = get_tp_group() + tp_size = get_tensor_model_parallel_world_size() + reduce_scatter = torch.ops.vllm.reduce_scatter.default( + mm_1, dim=0, world_size=tp_size, group_name=tp.unique_name) + + half_result = torch.empty_like(reduce_scatter, dtype=FP8_DTYPE) + + rmsnorm = torch.ops.higher_order.auto_functionalized( + torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, + result=half_result, + input=reduce_scatter, + residual=residual, + weight=rms_norm_weights, + scale=scale, + epsilon=self.epsilon, + ) + + all_gather = torch.ops.vllm.all_gather.default( + rmsnorm[1], + dim=0, + world_size=tp_size, + group_name=tp.unique_name) + return all_gather, rmsnorm[2] + + pm.register_replacement(pattern, replacement, self.get_inputs(), + pm.fwd_only, pm_pass) + + +class MiddleAllReduceRMSNormStaticFP8PatternExtra(AllReduceRMSNormPattern): + + def get_inputs(self): + mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype) + + residual = torch.empty([4, 4], device=self.device, dtype=self.dtype) + rms_norm_weights = torch.empty([4, 4], + device=self.device, + dtype=self.dtype) + result = torch.empty([4, 4], device=self.device, dtype=FP8_DTYPE) + scale = torch.empty([1, 1], device=self.device, dtype=torch.float32) + + return [ + result, + mm_1, + residual, + rms_norm_weights, + scale, + ] + + def register(self, pm_pass: PatternMatcherPass): + + def pattern( + result: torch.Tensor, + residual: torch.Tensor, + mm_1: torch.Tensor, + rms_norm_weights: torch.Tensor, + scale: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + all_reduce = tensor_model_parallel_all_reduce(mm_1) + + rmsnorm = torch.ops.higher_order.auto_functionalized( + torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, + epsilon=self.epsilon, + result=result, + input=all_reduce, + residual=residual, + weight=rms_norm_weights, + scale=scale, + ) + + return rmsnorm[1], 
rmsnorm[2] + + def replacement( + result: torch.Tensor, + residual: torch.Tensor, + mm_1: torch.Tensor, + rms_norm_weights: torch.Tensor, + scale: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + tp = get_tp_group() + tp_size = get_tensor_model_parallel_world_size() + reduce_scatter = torch.ops.vllm.reduce_scatter.default( + mm_1, dim=0, world_size=tp_size, group_name=tp.unique_name) + + half_result = torch.empty_like(reduce_scatter, dtype=FP8_DTYPE) + rmsnorm = torch.ops.higher_order.auto_functionalized( + torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, + epsilon=self.epsilon, + result=half_result, + input=reduce_scatter, + residual=residual, + weight=rms_norm_weights, + scale=scale, + ) + + all_gather = torch.ops.vllm.all_gather.default( + rmsnorm[1], + dim=0, + world_size=tp_size, + group_name=tp.unique_name) + return all_gather, rmsnorm[2] + + pm.register_replacement(pattern, replacement, self.get_inputs(), + pm.fwd_only, pm_pass) + + +class LastAllReduceRMSNormStaticFP8Pattern(AllReduceRMSNormPattern): + + def get_inputs(self): + mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype) + + residual = torch.empty([4, 4], device=self.device, dtype=self.dtype) + rms_norm_weights = torch.empty([4, 4], + device=self.device, + dtype=self.dtype) + result = torch.empty([4, 4], device=self.device, dtype=FP8_DTYPE) + scale = torch.empty([1, 1], device=self.device, dtype=torch.float32) + + return [ + result, + residual, + mm_1, + rms_norm_weights, + scale, + ] + + def register(self, pm_pass: PatternMatcherPass): + + def pattern( + result: torch.Tensor, + residual: torch.Tensor, + mm_1: torch.Tensor, + rms_norm_weights: torch.Tensor, + scale: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + all_reduce = tensor_model_parallel_all_reduce(mm_1) + + rmsnorm = torch.ops.higher_order.auto_functionalized( + torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, + result=result, + input=all_reduce, + residual=residual, + weight=rms_norm_weights, + scale=scale, + epsilon=self.epsilon, + ) + + return rmsnorm[1] + + def replacement( + result: torch.Tensor, + residual: torch.Tensor, + mm_1: torch.Tensor, + rms_norm_weights: torch.Tensor, + scale: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + tp = get_tp_group() + tp_size = get_tensor_model_parallel_world_size() + reduce_scatter = torch.ops.vllm.reduce_scatter.default( + mm_1, dim=0, world_size=tp_size, group_name=tp.unique_name) + + half_result = torch.empty_like(reduce_scatter, dtype=FP8_DTYPE) + rmsnorm = torch.ops.higher_order.auto_functionalized( + torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, + result=half_result, + input=reduce_scatter, + residual=residual, + weight=rms_norm_weights, + scale=scale, + epsilon=self.epsilon, + ) + + normalized = torch.ops.vllm.all_gather.default( + rmsnorm[1], + dim=0, + world_size=tp_size, + group_name=tp.unique_name) + + return normalized + + pm.register_replacement(pattern, replacement, self.get_inputs(), + pm.fwd_only, pm_pass) + + +class LastAllReduceRMSNormStaticFP8PatternExtra(AllReduceRMSNormPattern): + + def get_inputs(self): + mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype) + + residual = torch.empty([4, 4], device=self.device, dtype=self.dtype) + rms_norm_weights = torch.empty([4, 4], + device=self.device, + dtype=self.dtype) + result = torch.empty([4, 4], device=self.device, dtype=FP8_DTYPE) + scale = torch.empty([1, 1], device=self.device, dtype=torch.float32) + + return [ + result, + mm_1, + residual, + rms_norm_weights, + scale, 
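+            # placeholder tensors: only shapes/dtypes matter here, since these
+            # example inputs are used solely to trace the pattern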
+ ] + + def register(self, pm_pass: PatternMatcherPass): + + def pattern( + result: torch.Tensor, + residual: torch.Tensor, + mm_1: torch.Tensor, + rms_norm_weights: torch.Tensor, + scale: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + all_reduce = tensor_model_parallel_all_reduce(mm_1) + + rmsnorm = torch.ops.higher_order.auto_functionalized( + torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, + epsilon=self.epsilon, + result=result, + input=all_reduce, + residual=residual, + weight=rms_norm_weights, + scale=scale, + ) + + return rmsnorm[1] + + def replacement( + result: torch.Tensor, + residual: torch.Tensor, + mm_1: torch.Tensor, + rms_norm_weights: torch.Tensor, + scale: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + tp = get_tp_group() + tp_size = get_tensor_model_parallel_world_size() + reduce_scatter = torch.ops.vllm.reduce_scatter.default( + mm_1, dim=0, world_size=tp_size, group_name=tp.unique_name) + half_result = torch.empty_like(reduce_scatter, dtype=FP8_DTYPE) + rmsnorm = torch.ops.higher_order.auto_functionalized( + torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, + epsilon=self.epsilon, + result=half_result, + input=reduce_scatter, + residual=residual, + weight=rms_norm_weights, + scale=scale, + ) + + normalized = torch.ops.vllm.all_gather.default( + rmsnorm[1], + dim=0, + world_size=tp_size, + group_name=tp.unique_name) + + return normalized + + pm.register_replacement(pattern, replacement, self.get_inputs(), + pm.fwd_only, pm_pass) + + class SequenceParallelismPass(VllmInductorPass): def __init__(self, config: VllmConfig): @@ -241,6 +638,7 @@ def __init__(self, config: VllmConfig): self.patterns: PatternMatcherPass = PatternMatcherPass( pass_name="sequence_parallelism_pass") + for epsilon in [1e-5, 1e-6]: EmbeddingAllReduceRMSNormPattern( epsilon, self.model_dtype, self.device).register(self.patterns) @@ -250,6 +648,16 @@ def __init__(self, config: VllmConfig): LastAllReduceRMSNormPattern(epsilon, self.model_dtype, self.device).register(self.patterns) + + EmbeddingAllReduceRMSNormStaticFP8Pattern( + epsilon, self.model_dtype, self.device).register(self.patterns) + + MiddleAllReduceRMSNormStaticFP8Pattern( + epsilon, self.model_dtype, self.device).register(self.patterns) + + LastAllReduceRMSNormStaticFP8Pattern( + epsilon, self.model_dtype, self.device).register(self.patterns) + # WARNING: This is a hack to clear the pattern matcher cache # and allow multiple values of epsilon. 
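         # (the matcher records already-seen patterns in _seen_patterns, so
         # registrations for the second epsilon would otherwise be skipped)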
torch._inductor.pattern_matcher._seen_patterns.clear() @@ -260,8 +668,14 @@ def is_applicable_for_shape(self, shape: Optional[int]) -> bool: def __call__(self, graph: fx.Graph): self.begin() + # if get_tp_group().rank == 0: + # # if dist.get_rank(get_tp_group()) == 0: + # print(f"before sp graph {graph}") self.dump_graph(graph, "before_sequence_parallelism_pass") count = self.patterns.apply(graph) logger.debug("Replaced %s patterns", count) self.dump_graph(graph, "after_sequence_parallelism_pass") + # if get_tp_group().rank == 0: + # if dist.get_rank(get_tp_group()) == 0: + # print(f"after sp graph {graph}") self.end_and_log() From 5a006d74e360f3c25a4008bfcdf566bc73af771d Mon Sep 17 00:00:00 2001 From: cascade812 Date: Mon, 2 Jun 2025 20:55:57 +0000 Subject: [PATCH 2/6] add sq pass for rms + quant Signed-off-by: cascade812 --- vllm/compilation/sequence_parallelism.py | 225 +++++++++++++++++++---- 1 file changed, 185 insertions(+), 40 deletions(-) diff --git a/vllm/compilation/sequence_parallelism.py b/vllm/compilation/sequence_parallelism.py index 4f82bb3a88e4..a4186e6b79d0 100644 --- a/vllm/compilation/sequence_parallelism.py +++ b/vllm/compilation/sequence_parallelism.py @@ -26,6 +26,16 @@ def __init__(self, epsilon: float, dtype: torch.dtype, device: str): self.device = device +class AllReduceRMSNormQuantPattern: + + def __init__(self, epsilon: float, dtype: torch.dtype, device: str, + op: torch._ops.OpOverload): + self.epsilon = epsilon + self.dtype = dtype + self.device = device + self.op = op + + class EmbeddingAllReduceRMSNormPattern(AllReduceRMSNormPattern): def get_inputs(self): @@ -238,7 +248,7 @@ def replacement( FP8_DTYPE = current_platform.fp8_dtype() -class EmbeddingAllReduceRMSNormStaticFP8Pattern(AllReduceRMSNormPattern): +class EmbeddingAllReduceFusedRMSNormStaticFP8Pattern(AllReduceRMSNormPattern): def get_inputs(self): arg2_1 = torch.empty([16, 4], device=self.device, dtype=self.dtype) @@ -321,7 +331,7 @@ def replacement( pm.fwd_only, pm_pass) -class MiddleAllReduceRMSNormStaticFP8Pattern(AllReduceRMSNormPattern): +class MiddleAllReduceFusedRMSNormStaticFP8Pattern(AllReduceRMSNormPattern): def get_inputs(self): mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype) @@ -376,11 +386,11 @@ def replacement( reduce_scatter = torch.ops.vllm.reduce_scatter.default( mm_1, dim=0, world_size=tp_size, group_name=tp.unique_name) - half_result = torch.empty_like(reduce_scatter, dtype=FP8_DTYPE) + rs_result = torch.empty_like(reduce_scatter, dtype=FP8_DTYPE) rmsnorm = torch.ops.higher_order.auto_functionalized( torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, - result=half_result, + result=rs_result, input=reduce_scatter, residual=residual, weight=rms_norm_weights, @@ -399,7 +409,7 @@ def replacement( pm.fwd_only, pm_pass) -class MiddleAllReduceRMSNormStaticFP8PatternExtra(AllReduceRMSNormPattern): +class LastAllReduceFusedRMSNormStaticFP8Pattern(AllReduceRMSNormPattern): def get_inputs(self): mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype) @@ -413,8 +423,8 @@ def get_inputs(self): return [ result, - mm_1, residual, + mm_1, rms_norm_weights, scale, ] @@ -432,15 +442,15 @@ def pattern( rmsnorm = torch.ops.higher_order.auto_functionalized( torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, - epsilon=self.epsilon, result=result, input=all_reduce, residual=residual, weight=rms_norm_weights, scale=scale, + epsilon=self.epsilon, ) - return rmsnorm[1], rmsnorm[2] + return rmsnorm[1] def replacement( result: torch.Tensor, @@ -454,29 +464,124 @@ def 
replacement( reduce_scatter = torch.ops.vllm.reduce_scatter.default( mm_1, dim=0, world_size=tp_size, group_name=tp.unique_name) - half_result = torch.empty_like(reduce_scatter, dtype=FP8_DTYPE) + rs_result = torch.empty_like(reduce_scatter, dtype=FP8_DTYPE) rmsnorm = torch.ops.higher_order.auto_functionalized( torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, - epsilon=self.epsilon, - result=half_result, + result=rs_result, input=reduce_scatter, residual=residual, weight=rms_norm_weights, scale=scale, + epsilon=self.epsilon, ) - all_gather = torch.ops.vllm.all_gather.default( + normalized = torch.ops.vllm.all_gather.default( rmsnorm[1], dim=0, world_size=tp_size, group_name=tp.unique_name) - return all_gather, rmsnorm[2] + + return normalized pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass) -class LastAllReduceRMSNormStaticFP8Pattern(AllReduceRMSNormPattern): +class EmbeddingAllReduceRMSNormStaticFP8Pattern(AllReduceRMSNormQuantPattern): + + def get_inputs(self): + arg2_1 = torch.empty([16, 4], device=self.device, dtype=self.dtype) + mul_6 = torch.tensor([[3, 7, 1, 4, 9, 2, 5, 0]], + device=self.device, + dtype=torch.long) + unsqueeze = torch.rand([1, 8, 1], device=self.device, \ + dtype=self.dtype) > 0.5 + full_default = torch.zeros([1, 8, 4], device=self.device, \ + dtype=self.dtype) + result = torch.empty([1, 8, 4], device=self.device, dtype=FP8_DTYPE) + weight = torch.empty([4], device=self.device, dtype=self.dtype) + scale = torch.tensor(1.0, device=self.device, dtype=torch.float32) + return [arg2_1, mul_6, unsqueeze, full_default, result, weight, scale] + + def register(self, pm_pass: PatternMatcherPass): + + def pattern( + arg2_1: torch.Tensor, + mul_6: torch.Tensor, + unsqueeze: torch.Tensor, + full_default: torch.Tensor, + result: torch.Tensor, + weight: torch.Tensor, + scale: torch.Tensor, + ): + embedding = torch.ops.aten.embedding.default(arg2_1, mul_6) + where = torch.ops.aten.where.self(unsqueeze, full_default, + embedding) + all_reduce = tensor_model_parallel_all_reduce(where) + rmsnorm = torch.ops.higher_order.auto_functionalized( + torch.ops._C.rms_norm.default, + input=all_reduce, + weight=weight, + epsilon=self.epsilon, + ) + + static_fp8 = torch.ops.higher_order.auto_functionalized( + # torch.ops._C.static_scaled_fp8_quant.default, + self.op, + result=result, + input=rmsnorm[1], + scale=scale, + ) + + return static_fp8, all_reduce + + def replacement( + arg2_1: torch.Tensor, + mul_6: torch.Tensor, + unsqueeze: torch.Tensor, + full_default: torch.Tensor, + result: torch.Tensor, + weight: torch.Tensor, + scale: torch.Tensor, + ): + embedding = torch.ops.aten.embedding.default(arg2_1, mul_6) + where = torch.ops.aten.where.self(unsqueeze, full_default, + embedding) + + tp = get_tp_group() + tp_size = get_tensor_model_parallel_world_size() + reduce_scatter = torch.ops.vllm.reduce_scatter.default( + where, dim=0, world_size=tp_size, group_name=tp.unique_name) + + rmsnorm = torch.ops.higher_order.auto_functionalized( + torch.ops._C.rms_norm.default, + input=reduce_scatter, + weight=weight, + epsilon=self.epsilon, + ) + + quant_result = torch.empty_like(rmsnorm[1], dtype=result.dtype) + static_fp8 = torch.ops.higher_order.auto_functionalized( + # torch.ops._C.static_scaled_fp8_quant.default, + self.op, + result=quant_result, + input=rmsnorm[1], + scale=scale, + ) + + all_gather = torch.ops.vllm.all_gather.default( + static_fp8, + dim=0, + world_size=tp_size, + group_name=tp.unique_name) + + return all_gather, reduce_scatter + + 
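+        # register the (pattern, replacement) pair; pm.fwd_only traces only
+        # the forward graph of the example inputs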
pm.register_replacement(pattern, replacement, self.get_inputs(), + pm.fwd_only, pm_pass) + + +class MiddleAllReduceRMSNormStaticFP8Pattern(AllReduceRMSNormQuantPattern): def get_inputs(self): mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype) @@ -508,16 +613,22 @@ def pattern( all_reduce = tensor_model_parallel_all_reduce(mm_1) rmsnorm = torch.ops.higher_order.auto_functionalized( - torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, - result=result, + torch.ops._C.fused_add_rms_norm.default, input=all_reduce, residual=residual, weight=rms_norm_weights, - scale=scale, epsilon=self.epsilon, ) - return rmsnorm[1] + static_fp8 = torch.ops.higher_order.auto_functionalized( + # torch.ops._C.static_scaled_fp8_quant.default, + self.op, + result=result, + input=rmsnorm[1], + scale=scale, + ) + + return static_fp8, rmsnorm[2] def replacement( result: torch.Tensor, @@ -531,30 +642,35 @@ def replacement( reduce_scatter = torch.ops.vllm.reduce_scatter.default( mm_1, dim=0, world_size=tp_size, group_name=tp.unique_name) - half_result = torch.empty_like(reduce_scatter, dtype=FP8_DTYPE) rmsnorm = torch.ops.higher_order.auto_functionalized( - torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, - result=half_result, + torch.ops._C.fused_add_rms_norm.default, input=reduce_scatter, residual=residual, weight=rms_norm_weights, - scale=scale, epsilon=self.epsilon, ) - normalized = torch.ops.vllm.all_gather.default( - rmsnorm[1], + quant_result = torch.empty_like(rmsnorm[1], dtype=result.dtype) + static_fp8 = torch.ops.higher_order.auto_functionalized( + # torch.ops._C.static_scaled_fp8_quant.default, + self.op, + result=quant_result, + input=rmsnorm[1], + scale=scale, + ) + + all_gather = torch.ops.vllm.all_gather.default( + static_fp8, dim=0, world_size=tp_size, group_name=tp.unique_name) - - return normalized + return all_gather, rmsnorm[2] pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass) -class LastAllReduceRMSNormStaticFP8PatternExtra(AllReduceRMSNormPattern): +class LastAllReduceRMSNormStaticFP8Pattern(AllReduceRMSNormQuantPattern): def get_inputs(self): mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype) @@ -568,8 +684,8 @@ def get_inputs(self): return [ result, - mm_1, residual, + mm_1, rms_norm_weights, scale, ] @@ -586,16 +702,22 @@ def pattern( all_reduce = tensor_model_parallel_all_reduce(mm_1) rmsnorm = torch.ops.higher_order.auto_functionalized( - torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, - epsilon=self.epsilon, - result=result, + torch.ops._C.fused_add_rms_norm.default, input=all_reduce, residual=residual, weight=rms_norm_weights, + epsilon=self.epsilon, + ) + + static_fp8 = torch.ops.higher_order.auto_functionalized( + # torch.ops._C.static_scaled_fp8_quant.default, + self.op, + result=result, + input=rmsnorm[1], scale=scale, ) - return rmsnorm[1] + return static_fp8 def replacement( result: torch.Tensor, @@ -608,19 +730,26 @@ def replacement( tp_size = get_tensor_model_parallel_world_size() reduce_scatter = torch.ops.vllm.reduce_scatter.default( mm_1, dim=0, world_size=tp_size, group_name=tp.unique_name) - half_result = torch.empty_like(reduce_scatter, dtype=FP8_DTYPE) + rmsnorm = torch.ops.higher_order.auto_functionalized( - torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, - epsilon=self.epsilon, - result=half_result, + torch.ops._C.fused_add_rms_norm.default, input=reduce_scatter, residual=residual, weight=rms_norm_weights, + epsilon=self.epsilon, + ) + + quant_result = 
torch.empty_like(rmsnorm[1], dtype=result.dtype) + static_fp8 = torch.ops.higher_order.auto_functionalized( + # torch.ops._C.static_scaled_fp8_quant.default, + self.op, + result=quant_result, + input=rmsnorm[1], scale=scale, ) normalized = torch.ops.vllm.all_gather.default( - rmsnorm[1], + static_fp8, dim=0, world_size=tp_size, group_name=tp.unique_name) @@ -640,6 +769,7 @@ def __init__(self, config: VllmConfig): pass_name="sequence_parallelism_pass") for epsilon in [1e-5, 1e-6]: + # Normal RMSNorm patterns EmbeddingAllReduceRMSNormPattern( epsilon, self.model_dtype, self.device).register(self.patterns) @@ -649,15 +779,30 @@ def __init__(self, config: VllmConfig): LastAllReduceRMSNormPattern(epsilon, self.model_dtype, self.device).register(self.patterns) - EmbeddingAllReduceRMSNormStaticFP8Pattern( + # Fused RMSNorm + Static FP8 patterns + EmbeddingAllReduceFusedRMSNormStaticFP8Pattern( epsilon, self.model_dtype, self.device).register(self.patterns) - MiddleAllReduceRMSNormStaticFP8Pattern( + MiddleAllReduceFusedRMSNormStaticFP8Pattern( epsilon, self.model_dtype, self.device).register(self.patterns) - LastAllReduceRMSNormStaticFP8Pattern( + LastAllReduceFusedRMSNormStaticFP8Pattern( epsilon, self.model_dtype, self.device).register(self.patterns) + # RMSNorm + Static FP8 quantization patterns + EmbeddingAllReduceRMSNormStaticFP8Pattern( + epsilon, self.model_dtype, self.device, + torch.ops._C.static_scaled_fp8_quant.default).register( + self.patterns) + MiddleAllReduceRMSNormStaticFP8Pattern( + epsilon, self.model_dtype, self.device, + torch.ops._C.static_scaled_fp8_quant.default).register( + self.patterns) + LastAllReduceRMSNormStaticFP8Pattern( + epsilon, self.model_dtype, self.device, + torch.ops._C.static_scaled_fp8_quant.default).register( + self.patterns) + # WARNING: This is a hack to clear the pattern matcher cache # and allow multiple values of epsilon. 
torch._inductor.pattern_matcher._seen_patterns.clear() From 8c017382b4d73fedc1d36d8f607f906b02967767 Mon Sep 17 00:00:00 2001 From: cascade812 Date: Tue, 3 Jun 2025 03:22:16 +0000 Subject: [PATCH 3/6] add tests Signed-off-by: cascade812 --- tests/compile/test_sequence_parallelism.py | 57 +++++++++++++++++++++- vllm/compilation/sequence_parallelism.py | 24 ++++++--- 2 files changed, 74 insertions(+), 7 deletions(-) diff --git a/tests/compile/test_sequence_parallelism.py b/tests/compile/test_sequence_parallelism.py index 2cd7ebaacec0..bbffab1a6317 100644 --- a/tests/compile/test_sequence_parallelism.py +++ b/tests/compile/test_sequence_parallelism.py @@ -78,6 +78,61 @@ def ops_in_model(self): return [torch.ops._C.fused_add_rms_norm.default] +class TestQuantModel(torch.nn.Module): + + def __init__(self, hidden_size=16, intermediate_size=32): + super().__init__() + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.gate_proj = torch.nn.Parameter( + torch.empty((intermediate_size, hidden_size))) + self.norm = RMSNorm(hidden_size, 1e-05) + # Initialize weights + torch.nn.init.normal_(self.gate_proj, std=0.02) + + def forward(self, hidden_states, residual): + """ + Forward pass implementing the operations in the FX graph + + Args: + hidden_states: Input tensor + residual: Residual tensor from previous layer + + Returns: + Tuple containing the output tensor + """ + # Reshape input + view = hidden_states.reshape(-1, self.hidden_size) + + #matrix multiplication + permute = self.gate_proj.permute(1, 0) + mm = torch.mm(view, permute) + + # Tensor parallel all-reduce + all_reduce = tensor_model_parallel_all_reduce(mm) + + # layer normalization + norm_output, residual_output = self.norm(all_reduce, residual) + norm_output = norm_output.reshape(-1, self.hidden_size) + static_scaled_fp8_quant = \ + torch.ops.vllm.static_scaled_fp8_quant.default( + norm_output, dtype=torch.float16) + + return static_scaled_fp8_quant, residual_output + + def ops_in_model_before(self): + return [torch.ops.vllm.all_reduce.default] + + def ops_in_model_after(self): + return [ + torch.ops.vllm.reduce_scatter.default, + torch.ops.vllm.all_gather.default + ] + + def ops_in_model(self): + return [torch.ops._C.fused_add_rms_norm.default] + + @multi_gpu_test(num_gpus=2) @pytest.mark.parametrize("batch_size", [8]) @pytest.mark.parametrize("seq_len", [16]) @@ -145,7 +200,7 @@ def sequence_parallelism_pass_on_test_model(local_rank: int, world_size: int, func_pass = FixFunctionalizationPass(vllm_config) backend_func = TestBackend(sequence_parallelism_pass, func_pass) - model = TestModel(hidden_size, hidden_size * 2) + model = TestQuantModel(hidden_size, hidden_size * 2) hidden_states = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype) residual = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype) diff --git a/vllm/compilation/sequence_parallelism.py b/vllm/compilation/sequence_parallelism.py index a4186e6b79d0..d70b4ac5f408 100644 --- a/vllm/compilation/sequence_parallelism.py +++ b/vllm/compilation/sequence_parallelism.py @@ -525,11 +525,13 @@ def pattern( epsilon=self.epsilon, ) + reshape = torch.ops.aten.reshape.default( + rmsnorm[1], [-1, rmsnorm[1].shape[-1]]) static_fp8 = torch.ops.higher_order.auto_functionalized( # torch.ops._C.static_scaled_fp8_quant.default, self.op, result=result, - input=rmsnorm[1], + input=reshape, scale=scale, ) @@ -560,12 +562,14 @@ def replacement( epsilon=self.epsilon, ) + reshape = torch.ops.aten.reshape.default( + rmsnorm[1], [-1, 
rmsnorm[1].shape[-1]]) quant_result = torch.empty_like(rmsnorm[1], dtype=result.dtype) static_fp8 = torch.ops.higher_order.auto_functionalized( # torch.ops._C.static_scaled_fp8_quant.default, self.op, result=quant_result, - input=rmsnorm[1], + input=reshape, scale=scale, ) @@ -620,11 +624,13 @@ def pattern( epsilon=self.epsilon, ) + reshape = torch.ops.aten.reshape.default( + rmsnorm[1], [-1, rmsnorm[1].shape[-1]]) static_fp8 = torch.ops.higher_order.auto_functionalized( # torch.ops._C.static_scaled_fp8_quant.default, self.op, result=result, - input=rmsnorm[1], + input=reshape, scale=scale, ) @@ -651,11 +657,13 @@ def replacement( ) quant_result = torch.empty_like(rmsnorm[1], dtype=result.dtype) + reshape = torch.ops.aten.reshape.default( + rmsnorm[1], [-1, rmsnorm[1].shape[-1]]) static_fp8 = torch.ops.higher_order.auto_functionalized( # torch.ops._C.static_scaled_fp8_quant.default, self.op, result=quant_result, - input=rmsnorm[1], + input=reshape, scale=scale, ) @@ -709,11 +717,13 @@ def pattern( epsilon=self.epsilon, ) + reshape = torch.ops.aten.reshape.default( + rmsnorm[1], [-1, rmsnorm[1].shape[-1]]) static_fp8 = torch.ops.higher_order.auto_functionalized( # torch.ops._C.static_scaled_fp8_quant.default, self.op, result=result, - input=rmsnorm[1], + input=reshape, scale=scale, ) @@ -740,11 +750,13 @@ def replacement( ) quant_result = torch.empty_like(rmsnorm[1], dtype=result.dtype) + reshape = torch.ops.aten.reshape.default( + rmsnorm[1], [-1, rmsnorm[1].shape[-1]]) static_fp8 = torch.ops.higher_order.auto_functionalized( # torch.ops._C.static_scaled_fp8_quant.default, self.op, result=quant_result, - input=rmsnorm[1], + input=reshape, scale=scale, ) From a5a5ee6a3c97d532044bef29d5575da84d556905 Mon Sep 17 00:00:00 2001 From: cascade812 Date: Wed, 4 Jun 2025 18:59:09 +0000 Subject: [PATCH 4/6] update Signed-off-by: cascade812 --- tests/compile/test_sequence_parallelism.py | 33 ++++-- vllm/compilation/sequence_parallelism.py | 125 ++++++++++----------- vllm/config.py | 2 +- 3 files changed, 82 insertions(+), 78 deletions(-) diff --git a/tests/compile/test_sequence_parallelism.py b/tests/compile/test_sequence_parallelism.py index bbffab1a6317..54bb7cfb3021 100644 --- a/tests/compile/test_sequence_parallelism.py +++ b/tests/compile/test_sequence_parallelism.py @@ -84,11 +84,14 @@ def __init__(self, hidden_size=16, intermediate_size=32): super().__init__() self.hidden_size = hidden_size self.intermediate_size = intermediate_size - self.gate_proj = torch.nn.Parameter( - torch.empty((intermediate_size, hidden_size))) + self.gate_proj = torch.nn.Parameter(torch.empty( + (intermediate_size, hidden_size)), + requires_grad=False) self.norm = RMSNorm(hidden_size, 1e-05) # Initialize weights torch.nn.init.normal_(self.gate_proj, std=0.02) + # Register scale as a buffer + self.register_buffer('scale', torch.tensor(1.0, dtype=torch.float32)) def forward(self, hidden_states, residual): """ @@ -113,12 +116,16 @@ def forward(self, hidden_states, residual): # layer normalization norm_output, residual_output = self.norm(all_reduce, residual) - norm_output = norm_output.reshape(-1, self.hidden_size) - static_scaled_fp8_quant = \ - torch.ops.vllm.static_scaled_fp8_quant.default( - norm_output, dtype=torch.float16) - return static_scaled_fp8_quant, residual_output + scale = self.scale.to(norm_output.device) + quant_result = torch.empty(norm_output.shape, + dtype=current_platform.fp8_dtype()) + + torch.ops._C.static_scaled_fp8_quant.default(result=quant_result, + input=norm_output, + scale=scale) + + return 
quant_result, residual_output def ops_in_model_before(self): return [torch.ops.vllm.all_reduce.default] @@ -134,13 +141,15 @@ def ops_in_model(self): @multi_gpu_test(num_gpus=2) +@pytest.mark.parametrize("test_model", [TestModel, TestQuantModel]) @pytest.mark.parametrize("batch_size", [8]) @pytest.mark.parametrize("seq_len", [16]) @pytest.mark.parametrize("hidden_size", [16]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA") -def test_sequence_parallelism_pass(batch_size: int, seq_len: int, +def test_sequence_parallelism_pass(test_model: torch.nn.Module, + batch_size: int, seq_len: int, hidden_size: int, dtype: torch.dtype): num_processes = 2 @@ -148,14 +157,16 @@ def run_torch_spawn(fn, nprocs): # need to use torch.mp.spawn otherwise will have problems with # torch.distributed and cuda torch.multiprocessing.spawn(fn, - args=(num_processes, batch_size, seq_len, - hidden_size, dtype), + args=(num_processes, test_model, + batch_size, seq_len, hidden_size, + dtype), nprocs=nprocs) run_torch_spawn(sequence_parallelism_pass_on_test_model, num_processes) def sequence_parallelism_pass_on_test_model(local_rank: int, world_size: int, + test_model: torch.nn.Module, batch_size: int, seq_len: int, hidden_size: int, dtype: torch.dtype): @@ -200,7 +211,7 @@ def sequence_parallelism_pass_on_test_model(local_rank: int, world_size: int, func_pass = FixFunctionalizationPass(vllm_config) backend_func = TestBackend(sequence_parallelism_pass, func_pass) - model = TestQuantModel(hidden_size, hidden_size * 2) + model = test_model(hidden_size, hidden_size * 2) hidden_states = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype) residual = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype) diff --git a/vllm/compilation/sequence_parallelism.py b/vllm/compilation/sequence_parallelism.py index d70b4ac5f408..88ace9613e30 100644 --- a/vllm/compilation/sequence_parallelism.py +++ b/vllm/compilation/sequence_parallelism.py @@ -386,7 +386,7 @@ def replacement( reduce_scatter = torch.ops.vllm.reduce_scatter.default( mm_1, dim=0, world_size=tp_size, group_name=tp.unique_name) - rs_result = torch.empty_like(reduce_scatter, dtype=FP8_DTYPE) + rs_result = torch.empty_like(reduce_scatter, dtype=result.dtype) rmsnorm = torch.ops.higher_order.auto_functionalized( torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, @@ -464,7 +464,7 @@ def replacement( reduce_scatter = torch.ops.vllm.reduce_scatter.default( mm_1, dim=0, world_size=tp_size, group_name=tp.unique_name) - rs_result = torch.empty_like(reduce_scatter, dtype=FP8_DTYPE) + rs_result = torch.empty_like(reduce_scatter, dtype=result.dtype) rmsnorm = torch.ops.higher_order.auto_functionalized( torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, result=rs_result, @@ -498,51 +498,59 @@ def get_inputs(self): dtype=self.dtype) > 0.5 full_default = torch.zeros([1, 8, 4], device=self.device, \ dtype=self.dtype) - result = torch.empty([1, 8, 4], device=self.device, dtype=FP8_DTYPE) + rmsnorm_result = torch.empty([1, 8, 4], + device=self.device, + dtype=self.dtype) + quant_result = torch.empty([1, 8, 4], + device=self.device, + dtype=FP8_DTYPE) weight = torch.empty([4], device=self.device, dtype=self.dtype) scale = torch.tensor(1.0, device=self.device, dtype=torch.float32) - return [arg2_1, mul_6, unsqueeze, full_default, result, weight, scale] + return [ + arg2_1, mul_6, unsqueeze, full_default, rmsnorm_result, + quant_result, weight, scale + ] def 
register(self, pm_pass: PatternMatcherPass): def pattern( arg2_1: torch.Tensor, - mul_6: torch.Tensor, + mul: torch.Tensor, unsqueeze: torch.Tensor, full_default: torch.Tensor, - result: torch.Tensor, + rmsnorm_result: torch.Tensor, + quant_result: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor, ): - embedding = torch.ops.aten.embedding.default(arg2_1, mul_6) + embedding = torch.ops.aten.embedding.default(arg2_1, mul) where = torch.ops.aten.where.self(unsqueeze, full_default, embedding) all_reduce = tensor_model_parallel_all_reduce(where) rmsnorm = torch.ops.higher_order.auto_functionalized( torch.ops._C.rms_norm.default, + result=rmsnorm_result, input=all_reduce, weight=weight, epsilon=self.epsilon, ) - reshape = torch.ops.aten.reshape.default( - rmsnorm[1], [-1, rmsnorm[1].shape[-1]]) static_fp8 = torch.ops.higher_order.auto_functionalized( - # torch.ops._C.static_scaled_fp8_quant.default, self.op, - result=result, - input=reshape, + result=quant_result, + input=rmsnorm[1], scale=scale, ) - return static_fp8, all_reduce + return static_fp8[1], all_reduce def replacement( arg2_1: torch.Tensor, mul_6: torch.Tensor, unsqueeze: torch.Tensor, full_default: torch.Tensor, - result: torch.Tensor, + rmsnorm_result: torch.Tensor, + quant_result: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor, ): @@ -555,26 +563,27 @@ def replacement( reduce_scatter = torch.ops.vllm.reduce_scatter.default( where, dim=0, world_size=tp_size, group_name=tp.unique_name) + rmsnorm_result = torch.empty_like(reduce_scatter, + dtype=rmsnorm_result.dtype) rmsnorm = torch.ops.higher_order.auto_functionalized( torch.ops._C.rms_norm.default, + result=rmsnorm_result, input=reduce_scatter, weight=weight, epsilon=self.epsilon, ) - reshape = torch.ops.aten.reshape.default( - rmsnorm[1], [-1, rmsnorm[1].shape[-1]]) - quant_result = torch.empty_like(rmsnorm[1], dtype=result.dtype) + quant_result = torch.empty_like(rmsnorm[1], + dtype=quant_result.dtype) static_fp8 = torch.ops.higher_order.auto_functionalized( - # torch.ops._C.static_scaled_fp8_quant.default, self.op, result=quant_result, - input=reshape, + input=rmsnorm[1], scale=scale, ) all_gather = torch.ops.vllm.all_gather.default( - static_fp8, + static_fp8[1], dim=0, world_size=tp_size, group_name=tp.unique_name) @@ -624,17 +633,14 @@ def pattern( epsilon=self.epsilon, ) - reshape = torch.ops.aten.reshape.default( - rmsnorm[1], [-1, rmsnorm[1].shape[-1]]) static_fp8 = torch.ops.higher_order.auto_functionalized( - # torch.ops._C.static_scaled_fp8_quant.default, self.op, result=result, - input=reshape, + input=rmsnorm[1], scale=scale, ) - return static_fp8, rmsnorm[2] + return static_fp8[1], rmsnorm[2] def replacement( result: torch.Tensor, @@ -657,18 +663,15 @@ def replacement( ) quant_result = torch.empty_like(rmsnorm[1], dtype=result.dtype) - reshape = torch.ops.aten.reshape.default( - rmsnorm[1], [-1, rmsnorm[1].shape[-1]]) static_fp8 = torch.ops.higher_order.auto_functionalized( - # torch.ops._C.static_scaled_fp8_quant.default, self.op, result=quant_result, - input=reshape, + input=rmsnorm[1], scale=scale, ) all_gather = torch.ops.vllm.all_gather.default( - static_fp8, + static_fp8[1], dim=0, world_size=tp_size, group_name=tp.unique_name) @@ -717,17 +720,14 @@ def pattern( epsilon=self.epsilon, ) - reshape = torch.ops.aten.reshape.default( - rmsnorm[1], [-1, rmsnorm[1].shape[-1]]) static_fp8 = torch.ops.higher_order.auto_functionalized( - # torch.ops._C.static_scaled_fp8_quant.default, self.op, result=result, - input=reshape, + input=rmsnorm[1], 
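+                # self.op is the quantization op injected at construction
+                # (torch.ops._C.static_scaled_fp8_quant.default below)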
scale=scale, ) - return static_fp8 + return static_fp8[1] def replacement( result: torch.Tensor, @@ -750,18 +750,15 @@ def replacement( ) quant_result = torch.empty_like(rmsnorm[1], dtype=result.dtype) - reshape = torch.ops.aten.reshape.default( - rmsnorm[1], [-1, rmsnorm[1].shape[-1]]) static_fp8 = torch.ops.higher_order.auto_functionalized( - # torch.ops._C.static_scaled_fp8_quant.default, self.op, result=quant_result, - input=reshape, + input=rmsnorm[1], scale=scale, ) normalized = torch.ops.vllm.all_gather.default( - static_fp8, + static_fp8[1], dim=0, world_size=tp_size, group_name=tp.unique_name) @@ -781,15 +778,17 @@ def __init__(self, config: VllmConfig): pass_name="sequence_parallelism_pass") for epsilon in [1e-5, 1e-6]: - # Normal RMSNorm patterns - EmbeddingAllReduceRMSNormPattern( - epsilon, self.model_dtype, self.device).register(self.patterns) - - MiddleAllReduceRMSNormPattern(epsilon, self.model_dtype, - self.device).register(self.patterns) - - LastAllReduceRMSNormPattern(epsilon, self.model_dtype, - self.device).register(self.patterns) + # RMSNorm + Static FP8 quantization patterns + fp8_quant_op = torch.ops._C.static_scaled_fp8_quant.default + EmbeddingAllReduceRMSNormStaticFP8Pattern( + epsilon, self.model_dtype, self.device, + fp8_quant_op).register(self.patterns) + MiddleAllReduceRMSNormStaticFP8Pattern( + epsilon, self.model_dtype, self.device, + fp8_quant_op).register(self.patterns) + LastAllReduceRMSNormStaticFP8Pattern( + epsilon, self.model_dtype, self.device, + fp8_quant_op).register(self.patterns) # Fused RMSNorm + Static FP8 patterns EmbeddingAllReduceFusedRMSNormStaticFP8Pattern( @@ -801,19 +800,15 @@ def __init__(self, config: VllmConfig): LastAllReduceFusedRMSNormStaticFP8Pattern( epsilon, self.model_dtype, self.device).register(self.patterns) - # RMSNorm + Static FP8 quantization patterns - EmbeddingAllReduceRMSNormStaticFP8Pattern( - epsilon, self.model_dtype, self.device, - torch.ops._C.static_scaled_fp8_quant.default).register( - self.patterns) - MiddleAllReduceRMSNormStaticFP8Pattern( - epsilon, self.model_dtype, self.device, - torch.ops._C.static_scaled_fp8_quant.default).register( - self.patterns) - LastAllReduceRMSNormStaticFP8Pattern( - epsilon, self.model_dtype, self.device, - torch.ops._C.static_scaled_fp8_quant.default).register( - self.patterns) + # Normal RMSNorm patterns + EmbeddingAllReduceRMSNormPattern( + epsilon, self.model_dtype, self.device).register(self.patterns) + + MiddleAllReduceRMSNormPattern(epsilon, self.model_dtype, + self.device).register(self.patterns) + + LastAllReduceRMSNormPattern(epsilon, self.model_dtype, + self.device).register(self.patterns) # WARNING: This is a hack to clear the pattern matcher cache # and allow multiple values of epsilon. 
@@ -826,13 +821,11 @@ def is_applicable_for_shape(self, shape: Optional[int]) -> bool: def __call__(self, graph: fx.Graph): self.begin() # if get_tp_group().rank == 0: - # # if dist.get_rank(get_tp_group()) == 0: # print(f"before sp graph {graph}") self.dump_graph(graph, "before_sequence_parallelism_pass") count = self.patterns.apply(graph) logger.debug("Replaced %s patterns", count) self.dump_graph(graph, "after_sequence_parallelism_pass") - # if get_tp_group().rank == 0: - # if dist.get_rank(get_tp_group()) == 0: - # print(f"after sp graph {graph}") + if get_tp_group().rank == 0: + print(f"Replaced {count} patterns, after sp graph {graph}") self.end_and_log() diff --git a/vllm/config.py b/vllm/config.py index 6cec97a5f11b..b363911d49d9 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -4321,7 +4321,7 @@ def __post_init__(self): self.compilation_config.use_cudagraph = True self.compilation_config.cudagraph_num_of_warmups = 1 self.compilation_config.pass_config.enable_fusion = False - self.compilation_config.pass_config.enable_noop = False + # self.compilation_config.pass_config.enable_noop = False self.compilation_config.level = CompilationLevel.PIECEWISE self.compilation_config.set_splitting_ops_for_v1() From 184764d873bb6ac3d16f0e753aba539cde14371d Mon Sep 17 00:00:00 2001 From: cascade812 Date: Wed, 4 Jun 2025 22:10:57 +0000 Subject: [PATCH 5/6] fix and remove debug line Signed-off-by: cascade812 --- tests/compile/test_sequence_parallelism.py | 4 ++-- vllm/compilation/sequence_parallelism.py | 8 ++------ 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/tests/compile/test_sequence_parallelism.py b/tests/compile/test_sequence_parallelism.py index 54bb7cfb3021..485da89ffeb2 100644 --- a/tests/compile/test_sequence_parallelism.py +++ b/tests/compile/test_sequence_parallelism.py @@ -148,7 +148,7 @@ def ops_in_model(self): @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA") -def test_sequence_parallelism_pass(test_model: torch.nn.Module, +def test_sequence_parallelism_pass(test_model: type[torch.nn.Module], batch_size: int, seq_len: int, hidden_size: int, dtype: torch.dtype): num_processes = 2 @@ -166,7 +166,7 @@ def run_torch_spawn(fn, nprocs): def sequence_parallelism_pass_on_test_model(local_rank: int, world_size: int, - test_model: torch.nn.Module, + test_model: type[torch.nn.Module], batch_size: int, seq_len: int, hidden_size: int, dtype: torch.dtype): diff --git a/vllm/compilation/sequence_parallelism.py b/vllm/compilation/sequence_parallelism.py index 88ace9613e30..b6a4c99e7fd0 100644 --- a/vllm/compilation/sequence_parallelism.py +++ b/vllm/compilation/sequence_parallelism.py @@ -515,7 +515,7 @@ def register(self, pm_pass: PatternMatcherPass): def pattern( arg2_1: torch.Tensor, - mul: torch.Tensor, + mul_6: torch.Tensor, unsqueeze: torch.Tensor, full_default: torch.Tensor, rmsnorm_result: torch.Tensor, @@ -523,7 +523,7 @@ def pattern( weight: torch.Tensor, scale: torch.Tensor, ): - embedding = torch.ops.aten.embedding.default(arg2_1, mul) + embedding = torch.ops.aten.embedding.default(arg2_1, mul_6) where = torch.ops.aten.where.self(unsqueeze, full_default, embedding) all_reduce = tensor_model_parallel_all_reduce(where) @@ -820,12 +820,8 @@ def is_applicable_for_shape(self, shape: Optional[int]) -> bool: def __call__(self, graph: fx.Graph): self.begin() - # if get_tp_group().rank == 0: - # print(f"before sp graph {graph}") self.dump_graph(graph, 
"before_sequence_parallelism_pass") count = self.patterns.apply(graph) logger.debug("Replaced %s patterns", count) self.dump_graph(graph, "after_sequence_parallelism_pass") - if get_tp_group().rank == 0: - print(f"Replaced {count} patterns, after sp graph {graph}") self.end_and_log() From 7f19b80d0967a797ead20aee505c4ba06c76d10f Mon Sep 17 00:00:00 2001 From: cascade812 Date: Thu, 5 Jun 2025 04:27:54 +0000 Subject: [PATCH 6/6] update test and address comment Signed-off-by: cascade812 --- tests/compile/test_sequence_parallelism.py | 6 +- tests/distributed/test_sequence_parallel.py | 108 ++-- vllm/compilation/sequence_parallelism.py | 596 +++++++++----------- vllm/config.py | 2 +- 4 files changed, 309 insertions(+), 403 deletions(-) diff --git a/tests/compile/test_sequence_parallelism.py b/tests/compile/test_sequence_parallelism.py index 485da89ffeb2..474147854b26 100644 --- a/tests/compile/test_sequence_parallelism.py +++ b/tests/compile/test_sequence_parallelism.py @@ -197,10 +197,10 @@ def sequence_parallelism_pass_on_test_model(local_rank: int, world_size: int, # this is a fake model name to construct the model config # in the vllm_config, it's not really used. - model = "nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e" - vllm_config.model_config = ModelConfig(model=model, + model_name = "nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e" + vllm_config.model_config = ModelConfig(model=model_name, task="auto", - tokenizer=model, + tokenizer=model_name, tokenizer_mode="auto", trust_remote_code=True, dtype=dtype, diff --git a/tests/distributed/test_sequence_parallel.py b/tests/distributed/test_sequence_parallel.py index c9eba2b43788..c644d1e49e8e 100644 --- a/tests/distributed/test_sequence_parallel.py +++ b/tests/distributed/test_sequence_parallel.py @@ -27,7 +27,7 @@ class ParallelSetup(NamedTuple): tp_size: int pp_size: int - sp_enabled: bool + enable_fusion: bool eager_mode: bool chunked_prefill: bool @@ -66,49 +66,18 @@ def detailed( task: TaskOption = "auto", load_format: Optional[str] = None, ): + parallel_setups = [] + for eager_mode_val in [False, True]: + for pp_multiplier in [1, 2]: + for chunked_prefill_val in [False, True]: + parallel_setups.append( + ParallelSetup(tp_size=tp_base, + pp_size=pp_multiplier * pp_base, + enable_fusion=False, + eager_mode=eager_mode_val, + chunked_prefill=chunked_prefill_val)) return SPTestSettings( - parallel_setups=[ - ParallelSetup(tp_size=tp_base, - pp_size=pp_base, - sp_enabled=True, - eager_mode=False, - chunked_prefill=False), - ParallelSetup(tp_size=tp_base, - pp_size=pp_base, - sp_enabled=True, - eager_mode=False, - chunked_prefill=True), - ParallelSetup(tp_size=tp_base, - pp_size=pp_base, - sp_enabled=True, - eager_mode=True, - chunked_prefill=False), - ParallelSetup(tp_size=tp_base, - pp_size=pp_base, - sp_enabled=True, - eager_mode=True, - chunked_prefill=True), - ParallelSetup(tp_size=tp_base, - pp_size=2 * pp_base, - sp_enabled=True, - eager_mode=False, - chunked_prefill=False), - ParallelSetup(tp_size=tp_base, - pp_size=2 * pp_base, - sp_enabled=True, - eager_mode=False, - chunked_prefill=True), - ParallelSetup(tp_size=tp_base, - pp_size=2 * pp_base, - sp_enabled=True, - eager_mode=True, - chunked_prefill=False), - ParallelSetup(tp_size=tp_base, - pp_size=2 * pp_base, - sp_enabled=True, - eager_mode=True, - chunked_prefill=True) - ], + parallel_setups=parallel_setups, distributed_backends=["mp", "ray"], vllm_major_versions=["1", "1"], task=task, @@ -125,19 +94,44 @@ def fast( multi_node_only: bool = False, load_format: Optional[str] = None, ): 
+ parallel_setups = [] + for eager_mode_val in [False, True]: + for pp_multiplier in [1, 2]: + for chunked_prefill_val in [False, True]: + parallel_setups.append( + ParallelSetup(tp_size=tp_base, + pp_size=pp_multiplier * pp_base, + enable_fusion=False, + eager_mode=eager_mode_val, + chunked_prefill=chunked_prefill_val)) return SPTestSettings( - parallel_setups=[ + parallel_setups=parallel_setups, + distributed_backends=["mp", "ray"], + vllm_major_versions=["1", "1"], + task=task, + test_options=SPTestOptions(multi_node_only=multi_node_only, + load_format=load_format), + ) + + @staticmethod + def fp8_quant( + *, + tp_base: int = 2, + pp_base: int = 1, + task: TaskOption = "auto", + multi_node_only: bool = False, + load_format: Optional[str] = None, + ): + parallel_setups = [] + for fusion_val in [False, True]: + parallel_setups.append( ParallelSetup(tp_size=tp_base, pp_size=pp_base, - sp_enabled=True, - eager_mode=False, - chunked_prefill=False), - ParallelSetup(tp_size=tp_base, - pp_size=2 * pp_base, - sp_enabled=True, - eager_mode=False, - chunked_prefill=False), - ], + enable_fusion=fusion_val, + eager_mode=True, + chunked_prefill=False)) + return SPTestSettings( + parallel_setups=parallel_setups, distributed_backends=["mp", "ray"], vllm_major_versions=["1", "1"], task=task, @@ -170,7 +164,7 @@ def _compare_sp( ( tp_size, pp_size, - sp_enabled, + enable_fusion, eager_mode, chunked_prefill, ) = parallel_setup @@ -239,9 +233,9 @@ def _compare_sp( 'compile_sizes': [4, 8], 'splitting_ops': [], 'pass_config': { - 'enable_sequence_parallelism': sp_enabled, + 'enable_sequence_parallelism': True, + 'enable_fusion': enable_fusion, 'enable_noop': True, - 'enable_fusion': True, }, } @@ -290,12 +284,14 @@ def _compare_sp( SP_TEXT_GENERATION_MODELS = { # [Decoder-only] "meta-llama/Llama-3.2-1B-Instruct": SPTestSettings.fast(), + "RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8": SPTestSettings.fp8_quant(), } SP_TEST_MODELS = [ # TODO support other models # [LANGUAGE GENERATION] "meta-llama/Llama-3.2-1B-Instruct", + "RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8" ] diff --git a/vllm/compilation/sequence_parallelism.py b/vllm/compilation/sequence_parallelism.py index b6a4c99e7fd0..d22ff0a1a97f 100644 --- a/vllm/compilation/sequence_parallelism.py +++ b/vllm/compilation/sequence_parallelism.py @@ -18,35 +18,141 @@ logger = init_logger(__name__) -class AllReduceRMSNormPattern: +class _SequenceParallelPatternHelper: + """Base helper for sequence parallelism patterns.""" def __init__(self, epsilon: float, dtype: torch.dtype, device: str): self.epsilon = epsilon self.dtype = dtype self.device = device - - -class AllReduceRMSNormQuantPattern: + self.tp_group = get_tp_group() + self.tp_size = get_tensor_model_parallel_world_size() + + def _all_reduce(self, x: torch.Tensor) -> torch.Tensor: + return tensor_model_parallel_all_reduce(x) + + def _reduce_scatter(self, x: torch.Tensor) -> torch.Tensor: + return torch.ops.vllm.reduce_scatter.default( + x, + dim=0, + world_size=self.tp_size, + group_name=self.tp_group.unique_name) + + def _all_gather(self, x: torch.Tensor) -> torch.Tensor: + return torch.ops.vllm.all_gather.default( + x, + dim=0, + world_size=self.tp_size, + group_name=self.tp_group.unique_name) + + +class _RMSNormOpHelper(_SequenceParallelPatternHelper): + """Helper for RMSNorm operations in sequence parallelism patterns.""" + + def _functional_rmsnorm(self, result_buffer, input_tensor, weight_tensor): + return torch.ops.higher_order.auto_functionalized( + torch.ops._C.rms_norm.default, + result=result_buffer, + 
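+            # auto_functionalized exposes the mutating custom op as a pure
+            # call, which is the form the inductor pattern matcher operates on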
input=input_tensor, + weight=weight_tensor, + epsilon=self.epsilon) + + def _functional_fused_add_rmsnorm(self, input_tensor, residual_tensor, + weight_tensor): + return torch.ops.higher_order.auto_functionalized( + torch.ops._C.fused_add_rms_norm.default, + input=input_tensor, + residual=residual_tensor, + weight=weight_tensor, + epsilon=self.epsilon) + + +class _RMSNormQuantOpHelper(_SequenceParallelPatternHelper): + """Helper for RMSNorm + Quantization operations in sequence parallelism patterns.""" # noqa: E501 def __init__(self, epsilon: float, dtype: torch.dtype, device: str, - op: torch._ops.OpOverload): - self.epsilon = epsilon - self.dtype = dtype - self.device = device - self.op = op - + quant_op: torch._ops.OpOverload): + super().__init__(epsilon, dtype, device) + self.quant_op = quant_op + + def _functional_rmsnorm_then_quant(self, rmsnorm_result_buffer, + quant_result_buffer, input_tensor, + weight_tensor, scale_tensor): + rmsnorm_out_tuple = torch.ops.higher_order.auto_functionalized( + torch.ops._C.rms_norm.default, + result=rmsnorm_result_buffer, + input=input_tensor, + weight=weight_tensor, + epsilon=self.epsilon) + quant_out_tuple = torch.ops.higher_order.auto_functionalized( + self.quant_op, + result=quant_result_buffer, + input=rmsnorm_out_tuple[1], + scale=scale_tensor) + return quant_out_tuple + + def _functional_fused_add_rmsnorm_then_quant(self, quant_result_buffer, + input_tensor, residual_tensor, + weight_tensor, scale_tensor): + fused_add_rmsnorm_out_tuple = torch.ops.higher_order.auto_functionalized( # noqa: E501 + torch.ops._C.fused_add_rms_norm.default, + input=input_tensor, + residual=residual_tensor, + weight=weight_tensor, + epsilon=self.epsilon) + quant_out_tuple = torch.ops.higher_order.auto_functionalized( + self.quant_op, + result=quant_result_buffer, + input=fused_add_rmsnorm_out_tuple[1], + scale=scale_tensor) + return quant_out_tuple, fused_add_rmsnorm_out_tuple[2] + + +class _FusedRMSNormQuantOpHelper(_SequenceParallelPatternHelper): + """Helper for Fused RMSNorm + Quantization operations in sequence parallelism patterns.""" # noqa: E501 -class EmbeddingAllReduceRMSNormPattern(AllReduceRMSNormPattern): + def __init__(self, epsilon: float, dtype: torch.dtype, device: str, + fused_rmsnorm_quant_op: torch._ops.OpOverload, + fused_add_rmsnorm_quant_op: torch._ops.OpOverload): + super().__init__(epsilon, dtype, device) + self.fused_rmsnorm_quant_op = fused_rmsnorm_quant_op + self.fused_add_rmsnorm_quant_op = fused_add_rmsnorm_quant_op + + def _functional_fused_rmsnorm_quant(self, result_buffer, input_tensor, + weight_tensor, scale_tensor): + return torch.ops.higher_order.auto_functionalized( + self.fused_rmsnorm_quant_op, + result=result_buffer, + input=input_tensor, + weight=weight_tensor, + scale=scale_tensor, + epsilon=self.epsilon) + + def _functional_fused_add_rmsnorm_quant(self, result_buffer, input_tensor, + residual_tensor, weight_tensor, + scale_tensor): + return torch.ops.higher_order.auto_functionalized( + self.fused_add_rmsnorm_quant_op, + result=result_buffer, + input=input_tensor, + residual=residual_tensor, + weight=weight_tensor, + scale=scale_tensor, + epsilon=self.epsilon) + + +class EmbeddingAllReduceRMSNormPattern(_RMSNormOpHelper): def get_inputs(self): arg2_1 = torch.empty([16, 4], device=self.device, dtype=self.dtype) mul_6 = torch.tensor([[3, 7, 1, 4, 9, 2, 5, 0]], device=self.device, dtype=torch.long) - unsqueeze = torch.rand([1, 8, 1], device=self.device, \ - dtype=self.dtype) > 0.5 - full_default = torch.zeros([1, 8, 4], 
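+            # the mask (unsqueeze) and zero fill mimic the masked output of a
+            # vocab-parallel embedding, which this pattern is intended to match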
device=self.device, \ - dtype=self.dtype) + unsqueeze = torch.rand([1, 8, 1], device=self.device, + dtype=self.dtype) > 0.5 + full_default = torch.zeros([1, 8, 4], + device=self.device, + dtype=self.dtype) permute = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype) arg3_1 = torch.empty([4], device=self.device, dtype=self.dtype) @@ -65,14 +171,8 @@ def pattern( embedding = torch.ops.aten.embedding.default(arg2_1, mul_6) where = torch.ops.aten.where.self(unsqueeze, full_default, embedding) - all_reduce = tensor_model_parallel_all_reduce(where) - rmsnorm = torch.ops.higher_order.auto_functionalized( - torch.ops._C.rms_norm.default, - result=permute, - input=all_reduce, - weight=arg3_1, - epsilon=self.epsilon, - ) + all_reduce = self._all_reduce(where) + rmsnorm = self._functional_rmsnorm(permute, all_reduce, arg3_1) return rmsnorm[1], all_reduce @@ -87,26 +187,13 @@ def replacement( embedding = torch.ops.aten.embedding.default(arg2_1, mul_6) where = torch.ops.aten.where.self(unsqueeze, full_default, embedding) - - tp = get_tp_group() - tp_size = get_tensor_model_parallel_world_size() - reduce_scatter = torch.ops.vllm.reduce_scatter.default( - where, dim=0, world_size=tp_size, group_name=tp.unique_name) + reduce_scatter = self._reduce_scatter(where) rmsnorm_result = torch.empty_like(reduce_scatter) - rmsnorm = torch.ops.higher_order.auto_functionalized( - torch.ops._C.rms_norm.default, - result=rmsnorm_result, - input=reduce_scatter, - weight=arg3_1, - epsilon=self.epsilon, - ) - - all_gather = torch.ops.vllm.all_gather.default( - rmsnorm[1], - dim=0, - world_size=tp_size, - group_name=tp.unique_name) + rmsnorm = self._functional_rmsnorm(rmsnorm_result, reduce_scatter, + arg3_1) + + all_gather = self._all_gather(rmsnorm[1]) return all_gather, reduce_scatter @@ -114,7 +201,7 @@ def replacement( pm.fwd_only, pm_pass) -class MiddleAllReduceRMSNormPattern(AllReduceRMSNormPattern): +class MiddleAllReduceRMSNormPattern(_RMSNormOpHelper): def get_inputs(self): mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype) @@ -137,16 +224,9 @@ def pattern( mm_1: torch.Tensor, rms_norm_weights: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: - all_reduce = tensor_model_parallel_all_reduce(mm_1) - - rmsnorm = torch.ops.higher_order.auto_functionalized( - torch.ops._C.fused_add_rms_norm.default, - input=all_reduce, - residual=residual, - weight=rms_norm_weights, - epsilon=self.epsilon, - ) - + all_reduce = self._all_reduce(mm_1) + rmsnorm = self._functional_fused_add_rmsnorm( + all_reduce, residual, rms_norm_weights) return rmsnorm[1], rmsnorm[2] def replacement( @@ -154,32 +234,17 @@ def replacement( mm_1: torch.Tensor, rms_norm_weights: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: - tp = get_tp_group() - tp_size = get_tensor_model_parallel_world_size() - reduce_scatter = torch.ops.vllm.reduce_scatter.default( - mm_1, dim=0, world_size=tp_size, group_name=tp.unique_name) - - # TODO is it possible to extract epsilon from somewhere - rmsnorm = torch.ops.higher_order.auto_functionalized( - torch.ops._C.fused_add_rms_norm.default, - input=reduce_scatter, - residual=residual, - weight=rms_norm_weights, - epsilon=self.epsilon, - ) - - all_gather = torch.ops.vllm.all_gather.default( - rmsnorm[1], - dim=0, - world_size=tp_size, - group_name=tp.unique_name) + reduce_scatter = self._reduce_scatter(mm_1) + rmsnorm = self._functional_fused_add_rmsnorm( + reduce_scatter, residual, rms_norm_weights) + all_gather = self._all_gather(rmsnorm[1]) return all_gather, rmsnorm[2] 
pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass) -class LastAllReduceRMSNormPattern(AllReduceRMSNormPattern): +class LastAllReduceRMSNormPattern(_RMSNormOpHelper): def get_inputs(self): mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype) @@ -202,16 +267,9 @@ def pattern( mm_1: torch.Tensor, rms_norm_weights: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: - all_reduce = tensor_model_parallel_all_reduce(mm_1) - - rmsnorm = torch.ops.higher_order.auto_functionalized( - torch.ops._C.fused_add_rms_norm.default, - input=all_reduce, - residual=residual, - weight=rms_norm_weights, - epsilon=self.epsilon, - ) - + all_reduce = self._all_reduce(mm_1) + rmsnorm = self._functional_fused_add_rmsnorm( + all_reduce, residual, rms_norm_weights) return rmsnorm[1] def replacement( @@ -219,26 +277,10 @@ def replacement( mm_1: torch.Tensor, rms_norm_weights: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: - tp = get_tp_group() - tp_size = get_tensor_model_parallel_world_size() - reduce_scatter = torch.ops.vllm.reduce_scatter.default( - mm_1, dim=0, world_size=tp_size, group_name=tp.unique_name) - - # TODO is it possible to extract epsilon from somewhere - rmsnorm = torch.ops.higher_order.auto_functionalized( - torch.ops._C.fused_add_rms_norm.default, - input=reduce_scatter, - residual=residual, - weight=rms_norm_weights, - epsilon=self.epsilon, - ) - - normalized = torch.ops.vllm.all_gather.default( - rmsnorm[1], - dim=0, - world_size=tp_size, - group_name=tp.unique_name) - + reduce_scatter = self._reduce_scatter(mm_1) + rmsnorm = self._functional_fused_add_rmsnorm( + reduce_scatter, residual, rms_norm_weights) + normalized = self._all_gather(rmsnorm[1]) return normalized pm.register_replacement(pattern, replacement, self.get_inputs(), @@ -248,17 +290,28 @@ def replacement( FP8_DTYPE = current_platform.fp8_dtype() -class EmbeddingAllReduceFusedRMSNormStaticFP8Pattern(AllReduceRMSNormPattern): +class EmbeddingAllReduceFusedRMSNormStaticFP8Pattern(_FusedRMSNormQuantOpHelper + ): + + def __init__(self, epsilon: float, dtype: torch.dtype, device: str): + super().__init__(epsilon, + dtype, + device, + fused_rmsnorm_quant_op=torch.ops._C. + rms_norm_static_fp8_quant.default, + fused_add_rmsnorm_quant_op=torch.ops._C. 
+ fused_add_rms_norm_static_fp8_quant.default) def get_inputs(self): arg2_1 = torch.empty([16, 4], device=self.device, dtype=self.dtype) mul_6 = torch.tensor([[3, 7, 1, 4, 9, 2, 5, 0]], device=self.device, dtype=torch.long) - unsqueeze = torch.rand([1, 8, 1], device=self.device, \ - dtype=self.dtype) > 0.5 - full_default = torch.zeros([1, 8, 4], device=self.device, \ - dtype=self.dtype) + unsqueeze = torch.rand([1, 8, 1], device=self.device, + dtype=self.dtype) > 0.5 + full_default = torch.zeros([1, 8, 4], + device=self.device, + dtype=self.dtype) result = torch.empty([1, 8, 4], device=self.device, dtype=FP8_DTYPE) weight = torch.empty([4], device=self.device, dtype=self.dtype) scale = torch.tensor(1.0, device=self.device, dtype=torch.float32) @@ -278,16 +331,9 @@ def pattern( embedding = torch.ops.aten.embedding.default(arg2_1, mul_6) where = torch.ops.aten.where.self(unsqueeze, full_default, embedding) - all_reduce = tensor_model_parallel_all_reduce(where) - rmsnorm = torch.ops.higher_order.auto_functionalized( - torch.ops._C.rms_norm_static_fp8_quant.default, - result=result, - input=all_reduce, - weight=weight, - scale=scale, - epsilon=self.epsilon, - ) - + all_reduce = self._all_reduce(where) + rmsnorm = self._functional_fused_rmsnorm_quant( + result, all_reduce, weight, scale) return rmsnorm[1], all_reduce def replacement( @@ -302,28 +348,13 @@ def replacement( embedding = torch.ops.aten.embedding.default(arg2_1, mul_6) where = torch.ops.aten.where.self(unsqueeze, full_default, embedding) - - tp = get_tp_group() - tp_size = get_tensor_model_parallel_world_size() - reduce_scatter = torch.ops.vllm.reduce_scatter.default( - where, dim=0, world_size=tp_size, group_name=tp.unique_name) + reduce_scatter = self._reduce_scatter(where) rmsnorm_result = torch.empty_like(reduce_scatter, dtype=result.dtype) - rmsnorm = torch.ops.higher_order.auto_functionalized( - torch.ops._C.rms_norm_static_fp8_quant.default, - result=rmsnorm_result, - input=reduce_scatter, - weight=weight, - scale=scale, - epsilon=self.epsilon, - ) - - all_gather = torch.ops.vllm.all_gather.default( - rmsnorm[1], - dim=0, - world_size=tp_size, - group_name=tp.unique_name) + rmsnorm = self._functional_fused_rmsnorm_quant( + rmsnorm_result, reduce_scatter, weight, scale) + all_gather = self._all_gather(rmsnorm[1]) return all_gather, reduce_scatter @@ -331,7 +362,16 @@ def replacement( pm.fwd_only, pm_pass) -class MiddleAllReduceFusedRMSNormStaticFP8Pattern(AllReduceRMSNormPattern): +class MiddleAllReduceFusedRMSNormStaticFP8Pattern(_FusedRMSNormQuantOpHelper): + + def __init__(self, epsilon: float, dtype: torch.dtype, device: str): + super().__init__(epsilon, + dtype, + device, + fused_rmsnorm_quant_op=torch.ops._C. + rms_norm_static_fp8_quant.default, + fused_add_rmsnorm_quant_op=torch.ops._C. 
+ fused_add_rms_norm_static_fp8_quant.default) def get_inputs(self): mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype) @@ -360,18 +400,9 @@ def pattern( rms_norm_weights: torch.Tensor, scale: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: - all_reduce = tensor_model_parallel_all_reduce(mm_1) - - rmsnorm = torch.ops.higher_order.auto_functionalized( - torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, - result=result, - input=all_reduce, - residual=residual, - weight=rms_norm_weights, - scale=scale, - epsilon=self.epsilon, - ) - + all_reduce = self._all_reduce(mm_1) + rmsnorm = self._functional_fused_add_rmsnorm_quant( + result, all_reduce, residual, rms_norm_weights, scale) return rmsnorm[1], rmsnorm[2] def replacement( @@ -381,35 +412,27 @@ def replacement( rms_norm_weights: torch.Tensor, scale: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: - tp = get_tp_group() - tp_size = get_tensor_model_parallel_world_size() - reduce_scatter = torch.ops.vllm.reduce_scatter.default( - mm_1, dim=0, world_size=tp_size, group_name=tp.unique_name) - + reduce_scatter = self._reduce_scatter(mm_1) rs_result = torch.empty_like(reduce_scatter, dtype=result.dtype) - - rmsnorm = torch.ops.higher_order.auto_functionalized( - torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, - result=rs_result, - input=reduce_scatter, - residual=residual, - weight=rms_norm_weights, - scale=scale, - epsilon=self.epsilon, - ) - - all_gather = torch.ops.vllm.all_gather.default( - rmsnorm[1], - dim=0, - world_size=tp_size, - group_name=tp.unique_name) + rmsnorm = self._functional_fused_add_rmsnorm_quant( + rs_result, reduce_scatter, residual, rms_norm_weights, scale) + all_gather = self._all_gather(rmsnorm[1]) return all_gather, rmsnorm[2] pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass) -class LastAllReduceFusedRMSNormStaticFP8Pattern(AllReduceRMSNormPattern): +class LastAllReduceFusedRMSNormStaticFP8Pattern(_FusedRMSNormQuantOpHelper): + + def __init__(self, epsilon: float, dtype: torch.dtype, device: str): + super().__init__(epsilon, + dtype, + device, + fused_rmsnorm_quant_op=torch.ops._C. + rms_norm_static_fp8_quant.default, + fused_add_rmsnorm_quant_op=torch.ops._C. 
+ fused_add_rms_norm_static_fp8_quant.default) def get_inputs(self): mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype) @@ -438,18 +461,9 @@ def pattern( rms_norm_weights: torch.Tensor, scale: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: - all_reduce = tensor_model_parallel_all_reduce(mm_1) - - rmsnorm = torch.ops.higher_order.auto_functionalized( - torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, - result=result, - input=all_reduce, - residual=residual, - weight=rms_norm_weights, - scale=scale, - epsilon=self.epsilon, - ) - + all_reduce = self._all_reduce(mm_1) + rmsnorm = self._functional_fused_add_rmsnorm_quant( + result, all_reduce, residual, rms_norm_weights, scale) return rmsnorm[1] def replacement( @@ -459,45 +473,33 @@ def replacement( rms_norm_weights: torch.Tensor, scale: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: - tp = get_tp_group() - tp_size = get_tensor_model_parallel_world_size() - reduce_scatter = torch.ops.vllm.reduce_scatter.default( - mm_1, dim=0, world_size=tp_size, group_name=tp.unique_name) - + reduce_scatter = self._reduce_scatter(mm_1) rs_result = torch.empty_like(reduce_scatter, dtype=result.dtype) - rmsnorm = torch.ops.higher_order.auto_functionalized( - torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, - result=rs_result, - input=reduce_scatter, - residual=residual, - weight=rms_norm_weights, - scale=scale, - epsilon=self.epsilon, - ) - - normalized = torch.ops.vllm.all_gather.default( - rmsnorm[1], - dim=0, - world_size=tp_size, - group_name=tp.unique_name) - + rmsnorm = self._functional_fused_add_rmsnorm_quant( + rs_result, reduce_scatter, residual, rms_norm_weights, scale) + normalized = self._all_gather(rmsnorm[1]) return normalized pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass) -class EmbeddingAllReduceRMSNormStaticFP8Pattern(AllReduceRMSNormQuantPattern): +class EmbeddingAllReduceRMSNormStaticFP8Pattern(_RMSNormQuantOpHelper): + + def __init__(self, epsilon: float, dtype: torch.dtype, device: str, + op: torch._ops.OpOverload): + super().__init__(epsilon, dtype, device, quant_op=op) def get_inputs(self): arg2_1 = torch.empty([16, 4], device=self.device, dtype=self.dtype) mul_6 = torch.tensor([[3, 7, 1, 4, 9, 2, 5, 0]], device=self.device, dtype=torch.long) - unsqueeze = torch.rand([1, 8, 1], device=self.device, \ - dtype=self.dtype) > 0.5 - full_default = torch.zeros([1, 8, 4], device=self.device, \ - dtype=self.dtype) + unsqueeze = torch.rand([1, 8, 1], device=self.device, + dtype=self.dtype) > 0.5 + full_default = torch.zeros([1, 8, 4], + device=self.device, + dtype=self.dtype) rmsnorm_result = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype) @@ -526,22 +528,9 @@ def pattern( embedding = torch.ops.aten.embedding.default(arg2_1, mul_6) where = torch.ops.aten.where.self(unsqueeze, full_default, embedding) - all_reduce = tensor_model_parallel_all_reduce(where) - rmsnorm = torch.ops.higher_order.auto_functionalized( - torch.ops._C.rms_norm.default, - result=rmsnorm_result, - input=all_reduce, - weight=weight, - epsilon=self.epsilon, - ) - - static_fp8 = torch.ops.higher_order.auto_functionalized( - self.op, - result=quant_result, - input=rmsnorm[1], - scale=scale, - ) - + all_reduce = self._all_reduce(where) + static_fp8 = self._functional_rmsnorm_then_quant( + rmsnorm_result, quant_result, all_reduce, weight, scale) return static_fp8[1], all_reduce def replacement( @@ -557,36 +546,16 @@ def replacement( embedding = torch.ops.aten.embedding.default(arg2_1, 
mul_6) where = torch.ops.aten.where.self(unsqueeze, full_default, embedding) - - tp = get_tp_group() - tp_size = get_tensor_model_parallel_world_size() - reduce_scatter = torch.ops.vllm.reduce_scatter.default( - where, dim=0, world_size=tp_size, group_name=tp.unique_name) + reduce_scatter = self._reduce_scatter(where) rmsnorm_result = torch.empty_like(reduce_scatter, dtype=rmsnorm_result.dtype) - rmsnorm = torch.ops.higher_order.auto_functionalized( - torch.ops._C.rms_norm.default, - result=rmsnorm_result, - input=reduce_scatter, - weight=weight, - epsilon=self.epsilon, - ) - - quant_result = torch.empty_like(rmsnorm[1], - dtype=quant_result.dtype) - static_fp8 = torch.ops.higher_order.auto_functionalized( - self.op, - result=quant_result, - input=rmsnorm[1], - scale=scale, - ) - - all_gather = torch.ops.vllm.all_gather.default( - static_fp8[1], - dim=0, - world_size=tp_size, - group_name=tp.unique_name) + quant_result = torch.empty_like( + rmsnorm_result, # Output of RMSNorm + dtype=quant_result.dtype) + static_fp8 = self._functional_rmsnorm_then_quant( + rmsnorm_result, quant_result, reduce_scatter, weight, scale) + all_gather = self._all_gather(static_fp8[1]) return all_gather, reduce_scatter @@ -594,7 +563,11 @@ def replacement( pm.fwd_only, pm_pass) -class MiddleAllReduceRMSNormStaticFP8Pattern(AllReduceRMSNormQuantPattern): +class MiddleAllReduceRMSNormStaticFP8Pattern(_RMSNormQuantOpHelper): + + def __init__(self, epsilon: float, dtype: torch.dtype, device: str, + op: torch._ops.OpOverload): + super().__init__(epsilon, dtype, device, quant_op=op) def get_inputs(self): mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype) @@ -623,24 +596,10 @@ def pattern( rms_norm_weights: torch.Tensor, scale: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: - all_reduce = tensor_model_parallel_all_reduce(mm_1) - - rmsnorm = torch.ops.higher_order.auto_functionalized( - torch.ops._C.fused_add_rms_norm.default, - input=all_reduce, - residual=residual, - weight=rms_norm_weights, - epsilon=self.epsilon, - ) - - static_fp8 = torch.ops.higher_order.auto_functionalized( - self.op, - result=result, - input=rmsnorm[1], - scale=scale, - ) - - return static_fp8[1], rmsnorm[2] + all_reduce = self._all_reduce(mm_1) + static_fp8, rmsnorm_residual_out = self._functional_fused_add_rmsnorm_then_quant( # noqa: E501 + result, all_reduce, residual, rms_norm_weights, scale) + return static_fp8[1], rmsnorm_residual_out def replacement( result: torch.Tensor, @@ -649,39 +608,24 @@ def replacement( rms_norm_weights: torch.Tensor, scale: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: - tp = get_tp_group() - tp_size = get_tensor_model_parallel_world_size() - reduce_scatter = torch.ops.vllm.reduce_scatter.default( - mm_1, dim=0, world_size=tp_size, group_name=tp.unique_name) - - rmsnorm = torch.ops.higher_order.auto_functionalized( - torch.ops._C.fused_add_rms_norm.default, - input=reduce_scatter, - residual=residual, - weight=rms_norm_weights, - epsilon=self.epsilon, - ) - - quant_result = torch.empty_like(rmsnorm[1], dtype=result.dtype) - static_fp8 = torch.ops.higher_order.auto_functionalized( - self.op, - result=quant_result, - input=rmsnorm[1], - scale=scale, - ) - - all_gather = torch.ops.vllm.all_gather.default( - static_fp8[1], - dim=0, - world_size=tp_size, - group_name=tp.unique_name) - return all_gather, rmsnorm[2] + reduce_scatter = self._reduce_scatter(mm_1) + quant_result_buf = torch.empty_like(reduce_scatter, + dtype=result.dtype) + static_fp8, rmsnorm_residual_out = 
self._functional_fused_add_rmsnorm_then_quant( # noqa: E501 + quant_result_buf, reduce_scatter, residual, rms_norm_weights, + scale) + all_gather = self._all_gather(static_fp8[1]) + return all_gather, rmsnorm_residual_out pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass) -class LastAllReduceRMSNormStaticFP8Pattern(AllReduceRMSNormQuantPattern): +class LastAllReduceRMSNormStaticFP8Pattern(_RMSNormQuantOpHelper): + + def __init__(self, epsilon: float, dtype: torch.dtype, device: str, + op: torch._ops.OpOverload): + super().__init__(epsilon, dtype, device, quant_op=op) def get_inputs(self): mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype) @@ -710,23 +654,9 @@ def pattern( rms_norm_weights: torch.Tensor, scale: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: - all_reduce = tensor_model_parallel_all_reduce(mm_1) - - rmsnorm = torch.ops.higher_order.auto_functionalized( - torch.ops._C.fused_add_rms_norm.default, - input=all_reduce, - residual=residual, - weight=rms_norm_weights, - epsilon=self.epsilon, - ) - - static_fp8 = torch.ops.higher_order.auto_functionalized( - self.op, - result=result, - input=rmsnorm[1], - scale=scale, - ) - + all_reduce = self._all_reduce(mm_1) + static_fp8, _ = self._functional_fused_add_rmsnorm_then_quant( + result, all_reduce, residual, rms_norm_weights, scale) return static_fp8[1] def replacement( @@ -736,33 +666,13 @@ def replacement( rms_norm_weights: torch.Tensor, scale: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: - tp = get_tp_group() - tp_size = get_tensor_model_parallel_world_size() - reduce_scatter = torch.ops.vllm.reduce_scatter.default( - mm_1, dim=0, world_size=tp_size, group_name=tp.unique_name) - - rmsnorm = torch.ops.higher_order.auto_functionalized( - torch.ops._C.fused_add_rms_norm.default, - input=reduce_scatter, - residual=residual, - weight=rms_norm_weights, - epsilon=self.epsilon, - ) - - quant_result = torch.empty_like(rmsnorm[1], dtype=result.dtype) - static_fp8 = torch.ops.higher_order.auto_functionalized( - self.op, - result=quant_result, - input=rmsnorm[1], - scale=scale, - ) - - normalized = torch.ops.vllm.all_gather.default( - static_fp8[1], - dim=0, - world_size=tp_size, - group_name=tp.unique_name) - + reduce_scatter = self._reduce_scatter(mm_1) + quant_result_buf = torch.empty_like(reduce_scatter, + dtype=result.dtype) + static_fp8, _ = self._functional_fused_add_rmsnorm_then_quant( + quant_result_buf, reduce_scatter, residual, rms_norm_weights, + scale) + normalized = self._all_gather(static_fp8[1]) return normalized pm.register_replacement(pattern, replacement, self.get_inputs(), diff --git a/vllm/config.py b/vllm/config.py index b363911d49d9..04a500a9a86a 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -4320,7 +4320,7 @@ def __post_init__(self): self.compilation_config.custom_ops = ["none"] self.compilation_config.use_cudagraph = True self.compilation_config.cudagraph_num_of_warmups = 1 - self.compilation_config.pass_config.enable_fusion = False + # self.compilation_config.pass_config.enable_fusion = False # self.compilation_config.pass_config.enable_noop = False self.compilation_config.level = CompilationLevel.PIECEWISE self.compilation_config.set_splitting_ops_for_v1()
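
Reviewer note (illustrative only, not part of the applied diff): every pattern
above leans on the same fact: RMSNorm, with or without the fused residual add
or static-FP8 quant epilogue, operates row-wise over tokens, so
all_reduce -> norm can be rewritten as reduce_scatter -> norm-on-shard ->
all_gather with numerically equivalent results, while each rank only
normalizes 1/tp_size of the tokens. Below is a minimal single-process sketch
of that equivalence; the hand-rolled rms_norm is a stand-in for
torch.ops._C.rms_norm, and plain tensor ops simulate the vLLM collectives:

    import torch

    def rms_norm(x, weight, eps=1e-6):
        # Row-wise: each token row is normalized independently, which is
        # why sharding along the token dim (dim 0) is safe.
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight

    tp_size, tokens, hidden = 4, 8, 16
    torch.manual_seed(0)
    # Per-rank partial outputs, e.g. results of a row-parallel matmul.
    partials = [torch.randn(tokens, hidden) for _ in range(tp_size)]
    weight = torch.randn(hidden)

    # Original pattern: all_reduce, then RMSNorm on the full tensor.
    full = sum(partials)                   # all_reduce == elementwise sum
    ref = rms_norm(full, weight)

    # Rewritten pattern: reduce_scatter -> RMSNorm per shard -> all_gather.
    shards = full.chunk(tp_size, dim=0)    # reduce_scatter: rank r keeps shard r
    normed = [rms_norm(s, weight) for s in shards]
    out = torch.cat(normed, dim=0)         # all_gather along dim 0

    assert torch.allclose(ref, out)

The same argument covers the fused *_static_fp8_quant variants, since the
quant epilogue is also elementwise per row; only the buffer dtype changes,
which is why the replacements above allocate shard-sized result buffers with
torch.empty_like(reduce_scatter, dtype=result.dtype).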