diff --git a/tests/compile/test_sequence_parallelism.py b/tests/compile/test_sequence_parallelism.py index 2cd7ebaacec0..474147854b26 100644 --- a/tests/compile/test_sequence_parallelism.py +++ b/tests/compile/test_sequence_parallelism.py @@ -78,14 +78,78 @@ def ops_in_model(self): return [torch.ops._C.fused_add_rms_norm.default] +class TestQuantModel(torch.nn.Module): + + def __init__(self, hidden_size=16, intermediate_size=32): + super().__init__() + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.gate_proj = torch.nn.Parameter(torch.empty( + (intermediate_size, hidden_size)), + requires_grad=False) + self.norm = RMSNorm(hidden_size, 1e-05) + # Initialize weights + torch.nn.init.normal_(self.gate_proj, std=0.02) + # Register scale as a buffer + self.register_buffer('scale', torch.tensor(1.0, dtype=torch.float32)) + + def forward(self, hidden_states, residual): + """ + Forward pass implementing the operations in the FX graph + + Args: + hidden_states: Input tensor + residual: Residual tensor from previous layer + + Returns: + Tuple containing the output tensor + """ + # Reshape input + view = hidden_states.reshape(-1, self.hidden_size) + + #matrix multiplication + permute = self.gate_proj.permute(1, 0) + mm = torch.mm(view, permute) + + # Tensor parallel all-reduce + all_reduce = tensor_model_parallel_all_reduce(mm) + + # layer normalization + norm_output, residual_output = self.norm(all_reduce, residual) + + scale = self.scale.to(norm_output.device) + quant_result = torch.empty(norm_output.shape, + dtype=current_platform.fp8_dtype()) + + torch.ops._C.static_scaled_fp8_quant.default(result=quant_result, + input=norm_output, + scale=scale) + + return quant_result, residual_output + + def ops_in_model_before(self): + return [torch.ops.vllm.all_reduce.default] + + def ops_in_model_after(self): + return [ + torch.ops.vllm.reduce_scatter.default, + torch.ops.vllm.all_gather.default + ] + + def ops_in_model(self): + return [torch.ops._C.fused_add_rms_norm.default] + + @multi_gpu_test(num_gpus=2) +@pytest.mark.parametrize("test_model", [TestModel, TestQuantModel]) @pytest.mark.parametrize("batch_size", [8]) @pytest.mark.parametrize("seq_len", [16]) @pytest.mark.parametrize("hidden_size", [16]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA") -def test_sequence_parallelism_pass(batch_size: int, seq_len: int, +def test_sequence_parallelism_pass(test_model: type[torch.nn.Module], + batch_size: int, seq_len: int, hidden_size: int, dtype: torch.dtype): num_processes = 2 @@ -93,14 +157,16 @@ def run_torch_spawn(fn, nprocs): # need to use torch.mp.spawn otherwise will have problems with # torch.distributed and cuda torch.multiprocessing.spawn(fn, - args=(num_processes, batch_size, seq_len, - hidden_size, dtype), + args=(num_processes, test_model, + batch_size, seq_len, hidden_size, + dtype), nprocs=nprocs) run_torch_spawn(sequence_parallelism_pass_on_test_model, num_processes) def sequence_parallelism_pass_on_test_model(local_rank: int, world_size: int, + test_model: type[torch.nn.Module], batch_size: int, seq_len: int, hidden_size: int, dtype: torch.dtype): @@ -131,10 +197,10 @@ def sequence_parallelism_pass_on_test_model(local_rank: int, world_size: int, # this is a fake model name to construct the model config # in the vllm_config, it's not really used. - model = "nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e" - vllm_config.model_config = ModelConfig(model=model, + model_name = "nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e" + vllm_config.model_config = ModelConfig(model=model_name, task="auto", - tokenizer=model, + tokenizer=model_name, tokenizer_mode="auto", trust_remote_code=True, dtype=dtype, @@ -145,7 +211,7 @@ def sequence_parallelism_pass_on_test_model(local_rank: int, world_size: int, func_pass = FixFunctionalizationPass(vllm_config) backend_func = TestBackend(sequence_parallelism_pass, func_pass) - model = TestModel(hidden_size, hidden_size * 2) + model = test_model(hidden_size, hidden_size * 2) hidden_states = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype) residual = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype) diff --git a/tests/distributed/test_sequence_parallel.py b/tests/distributed/test_sequence_parallel.py index c9eba2b43788..c644d1e49e8e 100644 --- a/tests/distributed/test_sequence_parallel.py +++ b/tests/distributed/test_sequence_parallel.py @@ -27,7 +27,7 @@ class ParallelSetup(NamedTuple): tp_size: int pp_size: int - sp_enabled: bool + enable_fusion: bool eager_mode: bool chunked_prefill: bool @@ -66,49 +66,18 @@ def detailed( task: TaskOption = "auto", load_format: Optional[str] = None, ): + parallel_setups = [] + for eager_mode_val in [False, True]: + for pp_multiplier in [1, 2]: + for chunked_prefill_val in [False, True]: + parallel_setups.append( + ParallelSetup(tp_size=tp_base, + pp_size=pp_multiplier * pp_base, + enable_fusion=False, + eager_mode=eager_mode_val, + chunked_prefill=chunked_prefill_val)) return SPTestSettings( - parallel_setups=[ - ParallelSetup(tp_size=tp_base, - pp_size=pp_base, - sp_enabled=True, - eager_mode=False, - chunked_prefill=False), - ParallelSetup(tp_size=tp_base, - pp_size=pp_base, - sp_enabled=True, - eager_mode=False, - chunked_prefill=True), - ParallelSetup(tp_size=tp_base, - pp_size=pp_base, - sp_enabled=True, - eager_mode=True, - chunked_prefill=False), - ParallelSetup(tp_size=tp_base, - pp_size=pp_base, - sp_enabled=True, - eager_mode=True, - chunked_prefill=True), - ParallelSetup(tp_size=tp_base, - pp_size=2 * pp_base, - sp_enabled=True, - eager_mode=False, - chunked_prefill=False), - ParallelSetup(tp_size=tp_base, - pp_size=2 * pp_base, - sp_enabled=True, - eager_mode=False, - chunked_prefill=True), - ParallelSetup(tp_size=tp_base, - pp_size=2 * pp_base, - sp_enabled=True, - eager_mode=True, - chunked_prefill=False), - ParallelSetup(tp_size=tp_base, - pp_size=2 * pp_base, - sp_enabled=True, - eager_mode=True, - chunked_prefill=True) - ], + parallel_setups=parallel_setups, distributed_backends=["mp", "ray"], vllm_major_versions=["1", "1"], task=task, @@ -125,19 +94,44 @@ def fast( multi_node_only: bool = False, load_format: Optional[str] = None, ): + parallel_setups = [] + for eager_mode_val in [False, True]: + for pp_multiplier in [1, 2]: + for chunked_prefill_val in [False, True]: + parallel_setups.append( + ParallelSetup(tp_size=tp_base, + pp_size=pp_multiplier * pp_base, + enable_fusion=False, + eager_mode=eager_mode_val, + chunked_prefill=chunked_prefill_val)) return SPTestSettings( - parallel_setups=[ + parallel_setups=parallel_setups, + distributed_backends=["mp", "ray"], + vllm_major_versions=["1", "1"], + task=task, + test_options=SPTestOptions(multi_node_only=multi_node_only, + load_format=load_format), + ) + + @staticmethod + def fp8_quant( + *, + tp_base: int = 2, + pp_base: int = 1, + task: TaskOption = "auto", + multi_node_only: bool = False, + load_format: Optional[str] = None, + ): + parallel_setups = [] + for fusion_val in [False, True]: + parallel_setups.append( ParallelSetup(tp_size=tp_base, pp_size=pp_base, - sp_enabled=True, - eager_mode=False, - chunked_prefill=False), - ParallelSetup(tp_size=tp_base, - pp_size=2 * pp_base, - sp_enabled=True, - eager_mode=False, - chunked_prefill=False), - ], + enable_fusion=fusion_val, + eager_mode=True, + chunked_prefill=False)) + return SPTestSettings( + parallel_setups=parallel_setups, distributed_backends=["mp", "ray"], vllm_major_versions=["1", "1"], task=task, @@ -170,7 +164,7 @@ def _compare_sp( ( tp_size, pp_size, - sp_enabled, + enable_fusion, eager_mode, chunked_prefill, ) = parallel_setup @@ -239,9 +233,9 @@ def _compare_sp( 'compile_sizes': [4, 8], 'splitting_ops': [], 'pass_config': { - 'enable_sequence_parallelism': sp_enabled, + 'enable_sequence_parallelism': True, + 'enable_fusion': enable_fusion, 'enable_noop': True, - 'enable_fusion': True, }, } @@ -290,12 +284,14 @@ def _compare_sp( SP_TEXT_GENERATION_MODELS = { # [Decoder-only] "meta-llama/Llama-3.2-1B-Instruct": SPTestSettings.fast(), + "RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8": SPTestSettings.fp8_quant(), } SP_TEST_MODELS = [ # TODO support other models # [LANGUAGE GENERATION] "meta-llama/Llama-3.2-1B-Instruct", + "RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8" ] diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index 618b2fe94d3a..4092bfdbeb61 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -313,8 +313,8 @@ def process(self): # 0 is always None fused_return_mapping = {1: (quant_node, 1), 2: (rms_node, 2)} self.insert_fused_node(fused_return_mapping, - epsilon=rms_node.kwargs["epsilon"], - **kwargs) + **kwargs, + epsilon=rms_node.kwargs["epsilon"]) class RMSNormDynamicQuantPattern(RMSNormQuantPattern): diff --git a/vllm/compilation/sequence_parallelism.py b/vllm/compilation/sequence_parallelism.py index 17dded87fe8d..d22ff0a1a97f 100644 --- a/vllm/compilation/sequence_parallelism.py +++ b/vllm/compilation/sequence_parallelism.py @@ -11,31 +11,148 @@ from vllm.distributed.parallel_state import ( get_tensor_model_parallel_world_size) from vllm.logger import init_logger +from vllm.platforms import current_platform from .vllm_inductor_pass import VllmInductorPass logger = init_logger(__name__) -class AllReduceRMSNormPattern: +class _SequenceParallelPatternHelper: + """Base helper for sequence parallelism patterns.""" def __init__(self, epsilon: float, dtype: torch.dtype, device: str): self.epsilon = epsilon self.dtype = dtype self.device = device - - -class EmbeddingAllReduceRMSNormPattern(AllReduceRMSNormPattern): + self.tp_group = get_tp_group() + self.tp_size = get_tensor_model_parallel_world_size() + + def _all_reduce(self, x: torch.Tensor) -> torch.Tensor: + return tensor_model_parallel_all_reduce(x) + + def _reduce_scatter(self, x: torch.Tensor) -> torch.Tensor: + return torch.ops.vllm.reduce_scatter.default( + x, + dim=0, + world_size=self.tp_size, + group_name=self.tp_group.unique_name) + + def _all_gather(self, x: torch.Tensor) -> torch.Tensor: + return torch.ops.vllm.all_gather.default( + x, + dim=0, + world_size=self.tp_size, + group_name=self.tp_group.unique_name) + + +class _RMSNormOpHelper(_SequenceParallelPatternHelper): + """Helper for RMSNorm operations in sequence parallelism patterns.""" + + def _functional_rmsnorm(self, result_buffer, input_tensor, weight_tensor): + return torch.ops.higher_order.auto_functionalized( + torch.ops._C.rms_norm.default, + result=result_buffer, + input=input_tensor, + weight=weight_tensor, + epsilon=self.epsilon) + + def _functional_fused_add_rmsnorm(self, input_tensor, residual_tensor, + weight_tensor): + return torch.ops.higher_order.auto_functionalized( + torch.ops._C.fused_add_rms_norm.default, + input=input_tensor, + residual=residual_tensor, + weight=weight_tensor, + epsilon=self.epsilon) + + +class _RMSNormQuantOpHelper(_SequenceParallelPatternHelper): + """Helper for RMSNorm + Quantization operations in sequence parallelism patterns.""" # noqa: E501 + + def __init__(self, epsilon: float, dtype: torch.dtype, device: str, + quant_op: torch._ops.OpOverload): + super().__init__(epsilon, dtype, device) + self.quant_op = quant_op + + def _functional_rmsnorm_then_quant(self, rmsnorm_result_buffer, + quant_result_buffer, input_tensor, + weight_tensor, scale_tensor): + rmsnorm_out_tuple = torch.ops.higher_order.auto_functionalized( + torch.ops._C.rms_norm.default, + result=rmsnorm_result_buffer, + input=input_tensor, + weight=weight_tensor, + epsilon=self.epsilon) + quant_out_tuple = torch.ops.higher_order.auto_functionalized( + self.quant_op, + result=quant_result_buffer, + input=rmsnorm_out_tuple[1], + scale=scale_tensor) + return quant_out_tuple + + def _functional_fused_add_rmsnorm_then_quant(self, quant_result_buffer, + input_tensor, residual_tensor, + weight_tensor, scale_tensor): + fused_add_rmsnorm_out_tuple = torch.ops.higher_order.auto_functionalized( # noqa: E501 + torch.ops._C.fused_add_rms_norm.default, + input=input_tensor, + residual=residual_tensor, + weight=weight_tensor, + epsilon=self.epsilon) + quant_out_tuple = torch.ops.higher_order.auto_functionalized( + self.quant_op, + result=quant_result_buffer, + input=fused_add_rmsnorm_out_tuple[1], + scale=scale_tensor) + return quant_out_tuple, fused_add_rmsnorm_out_tuple[2] + + +class _FusedRMSNormQuantOpHelper(_SequenceParallelPatternHelper): + """Helper for Fused RMSNorm + Quantization operations in sequence parallelism patterns.""" # noqa: E501 + + def __init__(self, epsilon: float, dtype: torch.dtype, device: str, + fused_rmsnorm_quant_op: torch._ops.OpOverload, + fused_add_rmsnorm_quant_op: torch._ops.OpOverload): + super().__init__(epsilon, dtype, device) + self.fused_rmsnorm_quant_op = fused_rmsnorm_quant_op + self.fused_add_rmsnorm_quant_op = fused_add_rmsnorm_quant_op + + def _functional_fused_rmsnorm_quant(self, result_buffer, input_tensor, + weight_tensor, scale_tensor): + return torch.ops.higher_order.auto_functionalized( + self.fused_rmsnorm_quant_op, + result=result_buffer, + input=input_tensor, + weight=weight_tensor, + scale=scale_tensor, + epsilon=self.epsilon) + + def _functional_fused_add_rmsnorm_quant(self, result_buffer, input_tensor, + residual_tensor, weight_tensor, + scale_tensor): + return torch.ops.higher_order.auto_functionalized( + self.fused_add_rmsnorm_quant_op, + result=result_buffer, + input=input_tensor, + residual=residual_tensor, + weight=weight_tensor, + scale=scale_tensor, + epsilon=self.epsilon) + + +class EmbeddingAllReduceRMSNormPattern(_RMSNormOpHelper): def get_inputs(self): arg2_1 = torch.empty([16, 4], device=self.device, dtype=self.dtype) mul_6 = torch.tensor([[3, 7, 1, 4, 9, 2, 5, 0]], device=self.device, dtype=torch.long) - unsqueeze = torch.rand([1, 8, 1], device=self.device, \ - dtype=self.dtype) > 0.5 - full_default = torch.zeros([1, 8, 4], device=self.device, \ - dtype=self.dtype) + unsqueeze = torch.rand([1, 8, 1], device=self.device, + dtype=self.dtype) > 0.5 + full_default = torch.zeros([1, 8, 4], + device=self.device, + dtype=self.dtype) permute = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype) arg3_1 = torch.empty([4], device=self.device, dtype=self.dtype) @@ -54,14 +171,8 @@ def pattern( embedding = torch.ops.aten.embedding.default(arg2_1, mul_6) where = torch.ops.aten.where.self(unsqueeze, full_default, embedding) - all_reduce = tensor_model_parallel_all_reduce(where) - rmsnorm = torch.ops.higher_order.auto_functionalized( - torch.ops._C.rms_norm.default, - result=permute, - input=all_reduce, - weight=arg3_1, - epsilon=self.epsilon, - ) + all_reduce = self._all_reduce(where) + rmsnorm = self._functional_rmsnorm(permute, all_reduce, arg3_1) return rmsnorm[1], all_reduce @@ -76,26 +187,13 @@ def replacement( embedding = torch.ops.aten.embedding.default(arg2_1, mul_6) where = torch.ops.aten.where.self(unsqueeze, full_default, embedding) - - tp = get_tp_group() - tp_size = get_tensor_model_parallel_world_size() - reduce_scatter = torch.ops.vllm.reduce_scatter.default( - where, dim=0, world_size=tp_size, group_name=tp.unique_name) + reduce_scatter = self._reduce_scatter(where) rmsnorm_result = torch.empty_like(reduce_scatter) - rmsnorm = torch.ops.higher_order.auto_functionalized( - torch.ops._C.rms_norm.default, - result=rmsnorm_result, - input=reduce_scatter, - weight=arg3_1, - epsilon=self.epsilon, - ) - - all_gather = torch.ops.vllm.all_gather.default( - rmsnorm[1], - dim=0, - world_size=tp_size, - group_name=tp.unique_name) + rmsnorm = self._functional_rmsnorm(rmsnorm_result, reduce_scatter, + arg3_1) + + all_gather = self._all_gather(rmsnorm[1]) return all_gather, reduce_scatter @@ -103,7 +201,50 @@ def replacement( pm.fwd_only, pm_pass) -class MiddleAllReduceRMSNormPattern(AllReduceRMSNormPattern): +class MiddleAllReduceRMSNormPattern(_RMSNormOpHelper): + + def get_inputs(self): + mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype) + + residual = torch.empty([4, 4], device=self.device, dtype=self.dtype) + rms_norm_weights = torch.empty([4, 4], + device=self.device, + dtype=self.dtype) + + return [ + residual, + mm_1, + rms_norm_weights, + ] + + def register(self, pm_pass: PatternMatcherPass): + + def pattern( + residual: torch.Tensor, + mm_1: torch.Tensor, + rms_norm_weights: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + all_reduce = self._all_reduce(mm_1) + rmsnorm = self._functional_fused_add_rmsnorm( + all_reduce, residual, rms_norm_weights) + return rmsnorm[1], rmsnorm[2] + + def replacement( + residual: torch.Tensor, + mm_1: torch.Tensor, + rms_norm_weights: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + reduce_scatter = self._reduce_scatter(mm_1) + rmsnorm = self._functional_fused_add_rmsnorm( + reduce_scatter, residual, rms_norm_weights) + all_gather = self._all_gather(rmsnorm[1]) + return all_gather, rmsnorm[2] + + pm.register_replacement(pattern, replacement, self.get_inputs(), + pm.fwd_only, pm_pass) + + +class LastAllReduceRMSNormPattern(_RMSNormOpHelper): def get_inputs(self): mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype) @@ -126,49 +267,172 @@ def pattern( mm_1: torch.Tensor, rms_norm_weights: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: - all_reduce = tensor_model_parallel_all_reduce(mm_1) + all_reduce = self._all_reduce(mm_1) + rmsnorm = self._functional_fused_add_rmsnorm( + all_reduce, residual, rms_norm_weights) + return rmsnorm[1] + + def replacement( + residual: torch.Tensor, + mm_1: torch.Tensor, + rms_norm_weights: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + reduce_scatter = self._reduce_scatter(mm_1) + rmsnorm = self._functional_fused_add_rmsnorm( + reduce_scatter, residual, rms_norm_weights) + normalized = self._all_gather(rmsnorm[1]) + return normalized + + pm.register_replacement(pattern, replacement, self.get_inputs(), + pm.fwd_only, pm_pass) + + +FP8_DTYPE = current_platform.fp8_dtype() + + +class EmbeddingAllReduceFusedRMSNormStaticFP8Pattern(_FusedRMSNormQuantOpHelper + ): + + def __init__(self, epsilon: float, dtype: torch.dtype, device: str): + super().__init__(epsilon, + dtype, + device, + fused_rmsnorm_quant_op=torch.ops._C. + rms_norm_static_fp8_quant.default, + fused_add_rmsnorm_quant_op=torch.ops._C. + fused_add_rms_norm_static_fp8_quant.default) + + def get_inputs(self): + arg2_1 = torch.empty([16, 4], device=self.device, dtype=self.dtype) + mul_6 = torch.tensor([[3, 7, 1, 4, 9, 2, 5, 0]], + device=self.device, + dtype=torch.long) + unsqueeze = torch.rand([1, 8, 1], device=self.device, + dtype=self.dtype) > 0.5 + full_default = torch.zeros([1, 8, 4], + device=self.device, + dtype=self.dtype) + result = torch.empty([1, 8, 4], device=self.device, dtype=FP8_DTYPE) + weight = torch.empty([4], device=self.device, dtype=self.dtype) + scale = torch.tensor(1.0, device=self.device, dtype=torch.float32) + return [arg2_1, mul_6, unsqueeze, full_default, result, weight, scale] + + def register(self, pm_pass: PatternMatcherPass): + + def pattern( + arg2_1: torch.Tensor, + mul_6: torch.Tensor, + unsqueeze: torch.Tensor, + full_default: torch.Tensor, + result: torch.Tensor, + weight: torch.Tensor, + scale: torch.Tensor, + ): + embedding = torch.ops.aten.embedding.default(arg2_1, mul_6) + where = torch.ops.aten.where.self(unsqueeze, full_default, + embedding) + all_reduce = self._all_reduce(where) + rmsnorm = self._functional_fused_rmsnorm_quant( + result, all_reduce, weight, scale) + return rmsnorm[1], all_reduce + + def replacement( + arg2_1: torch.Tensor, + mul_6: torch.Tensor, + unsqueeze: torch.Tensor, + full_default: torch.Tensor, + result: torch.Tensor, + weight: torch.Tensor, + scale: torch.Tensor, + ): + embedding = torch.ops.aten.embedding.default(arg2_1, mul_6) + where = torch.ops.aten.where.self(unsqueeze, full_default, + embedding) + reduce_scatter = self._reduce_scatter(where) + + rmsnorm_result = torch.empty_like(reduce_scatter, + dtype=result.dtype) + rmsnorm = self._functional_fused_rmsnorm_quant( + rmsnorm_result, reduce_scatter, weight, scale) + all_gather = self._all_gather(rmsnorm[1]) + + return all_gather, reduce_scatter + + pm.register_replacement(pattern, replacement, self.get_inputs(), + pm.fwd_only, pm_pass) + + +class MiddleAllReduceFusedRMSNormStaticFP8Pattern(_FusedRMSNormQuantOpHelper): + + def __init__(self, epsilon: float, dtype: torch.dtype, device: str): + super().__init__(epsilon, + dtype, + device, + fused_rmsnorm_quant_op=torch.ops._C. + rms_norm_static_fp8_quant.default, + fused_add_rmsnorm_quant_op=torch.ops._C. + fused_add_rms_norm_static_fp8_quant.default) - rmsnorm = torch.ops.higher_order.auto_functionalized( - torch.ops._C.fused_add_rms_norm.default, - input=all_reduce, - residual=residual, - weight=rms_norm_weights, - epsilon=self.epsilon, - ) + def get_inputs(self): + mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype) + + residual = torch.empty([4, 4], device=self.device, dtype=self.dtype) + rms_norm_weights = torch.empty([4, 4], + device=self.device, + dtype=self.dtype) + result = torch.empty([4, 4], device=self.device, dtype=FP8_DTYPE) + scale = torch.empty([1, 1], device=self.device, dtype=torch.float32) + return [ + result, + residual, + mm_1, + rms_norm_weights, + scale, + ] + + def register(self, pm_pass: PatternMatcherPass): + + def pattern( + result: torch.Tensor, + residual: torch.Tensor, + mm_1: torch.Tensor, + rms_norm_weights: torch.Tensor, + scale: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + all_reduce = self._all_reduce(mm_1) + rmsnorm = self._functional_fused_add_rmsnorm_quant( + result, all_reduce, residual, rms_norm_weights, scale) return rmsnorm[1], rmsnorm[2] def replacement( + result: torch.Tensor, residual: torch.Tensor, mm_1: torch.Tensor, rms_norm_weights: torch.Tensor, + scale: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: - tp = get_tp_group() - tp_size = get_tensor_model_parallel_world_size() - reduce_scatter = torch.ops.vllm.reduce_scatter.default( - mm_1, dim=0, world_size=tp_size, group_name=tp.unique_name) - - # TODO is it possible to extract epsilon from somewhere - rmsnorm = torch.ops.higher_order.auto_functionalized( - torch.ops._C.fused_add_rms_norm.default, - input=reduce_scatter, - residual=residual, - weight=rms_norm_weights, - epsilon=self.epsilon, - ) - - all_gather = torch.ops.vllm.all_gather.default( - rmsnorm[1], - dim=0, - world_size=tp_size, - group_name=tp.unique_name) + reduce_scatter = self._reduce_scatter(mm_1) + rs_result = torch.empty_like(reduce_scatter, dtype=result.dtype) + rmsnorm = self._functional_fused_add_rmsnorm_quant( + rs_result, reduce_scatter, residual, rms_norm_weights, scale) + all_gather = self._all_gather(rmsnorm[1]) return all_gather, rmsnorm[2] pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass) -class LastAllReduceRMSNormPattern(AllReduceRMSNormPattern): +class LastAllReduceFusedRMSNormStaticFP8Pattern(_FusedRMSNormQuantOpHelper): + + def __init__(self, epsilon: float, dtype: torch.dtype, device: str): + super().__init__(epsilon, + dtype, + device, + fused_rmsnorm_quant_op=torch.ops._C. + rms_norm_static_fp8_quant.default, + fused_add_rmsnorm_quant_op=torch.ops._C. + fused_add_rms_norm_static_fp8_quant.default) def get_inputs(self): mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype) @@ -177,57 +441,238 @@ def get_inputs(self): rms_norm_weights = torch.empty([4, 4], device=self.device, dtype=self.dtype) + result = torch.empty([4, 4], device=self.device, dtype=FP8_DTYPE) + scale = torch.empty([1, 1], device=self.device, dtype=torch.float32) return [ + result, residual, mm_1, rms_norm_weights, + scale, ] def register(self, pm_pass: PatternMatcherPass): def pattern( + result: torch.Tensor, residual: torch.Tensor, mm_1: torch.Tensor, rms_norm_weights: torch.Tensor, + scale: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: - all_reduce = tensor_model_parallel_all_reduce(mm_1) + all_reduce = self._all_reduce(mm_1) + rmsnorm = self._functional_fused_add_rmsnorm_quant( + result, all_reduce, residual, rms_norm_weights, scale) + return rmsnorm[1] - rmsnorm = torch.ops.higher_order.auto_functionalized( - torch.ops._C.fused_add_rms_norm.default, - input=all_reduce, - residual=residual, - weight=rms_norm_weights, - epsilon=self.epsilon, - ) + def replacement( + result: torch.Tensor, + residual: torch.Tensor, + mm_1: torch.Tensor, + rms_norm_weights: torch.Tensor, + scale: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + reduce_scatter = self._reduce_scatter(mm_1) + rs_result = torch.empty_like(reduce_scatter, dtype=result.dtype) + rmsnorm = self._functional_fused_add_rmsnorm_quant( + rs_result, reduce_scatter, residual, rms_norm_weights, scale) + normalized = self._all_gather(rmsnorm[1]) + return normalized - return rmsnorm[1] + pm.register_replacement(pattern, replacement, self.get_inputs(), + pm.fwd_only, pm_pass) + + +class EmbeddingAllReduceRMSNormStaticFP8Pattern(_RMSNormQuantOpHelper): + + def __init__(self, epsilon: float, dtype: torch.dtype, device: str, + op: torch._ops.OpOverload): + super().__init__(epsilon, dtype, device, quant_op=op) + + def get_inputs(self): + arg2_1 = torch.empty([16, 4], device=self.device, dtype=self.dtype) + mul_6 = torch.tensor([[3, 7, 1, 4, 9, 2, 5, 0]], + device=self.device, + dtype=torch.long) + unsqueeze = torch.rand([1, 8, 1], device=self.device, + dtype=self.dtype) > 0.5 + full_default = torch.zeros([1, 8, 4], + device=self.device, + dtype=self.dtype) + rmsnorm_result = torch.empty([1, 8, 4], + device=self.device, + dtype=self.dtype) + quant_result = torch.empty([1, 8, 4], + device=self.device, + dtype=FP8_DTYPE) + weight = torch.empty([4], device=self.device, dtype=self.dtype) + scale = torch.tensor(1.0, device=self.device, dtype=torch.float32) + return [ + arg2_1, mul_6, unsqueeze, full_default, rmsnorm_result, + quant_result, weight, scale + ] + + def register(self, pm_pass: PatternMatcherPass): + + def pattern( + arg2_1: torch.Tensor, + mul_6: torch.Tensor, + unsqueeze: torch.Tensor, + full_default: torch.Tensor, + rmsnorm_result: torch.Tensor, + quant_result: torch.Tensor, + weight: torch.Tensor, + scale: torch.Tensor, + ): + embedding = torch.ops.aten.embedding.default(arg2_1, mul_6) + where = torch.ops.aten.where.self(unsqueeze, full_default, + embedding) + all_reduce = self._all_reduce(where) + static_fp8 = self._functional_rmsnorm_then_quant( + rmsnorm_result, quant_result, all_reduce, weight, scale) + return static_fp8[1], all_reduce + + def replacement( + arg2_1: torch.Tensor, + mul_6: torch.Tensor, + unsqueeze: torch.Tensor, + full_default: torch.Tensor, + rmsnorm_result: torch.Tensor, + quant_result: torch.Tensor, + weight: torch.Tensor, + scale: torch.Tensor, + ): + embedding = torch.ops.aten.embedding.default(arg2_1, mul_6) + where = torch.ops.aten.where.self(unsqueeze, full_default, + embedding) + reduce_scatter = self._reduce_scatter(where) + + rmsnorm_result = torch.empty_like(reduce_scatter, + dtype=rmsnorm_result.dtype) + quant_result = torch.empty_like( + rmsnorm_result, # Output of RMSNorm + dtype=quant_result.dtype) + static_fp8 = self._functional_rmsnorm_then_quant( + rmsnorm_result, quant_result, reduce_scatter, weight, scale) + all_gather = self._all_gather(static_fp8[1]) + + return all_gather, reduce_scatter + + pm.register_replacement(pattern, replacement, self.get_inputs(), + pm.fwd_only, pm_pass) + + +class MiddleAllReduceRMSNormStaticFP8Pattern(_RMSNormQuantOpHelper): + + def __init__(self, epsilon: float, dtype: torch.dtype, device: str, + op: torch._ops.OpOverload): + super().__init__(epsilon, dtype, device, quant_op=op) + + def get_inputs(self): + mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype) + + residual = torch.empty([4, 4], device=self.device, dtype=self.dtype) + rms_norm_weights = torch.empty([4, 4], + device=self.device, + dtype=self.dtype) + result = torch.empty([4, 4], device=self.device, dtype=FP8_DTYPE) + scale = torch.empty([1, 1], device=self.device, dtype=torch.float32) + + return [ + result, + residual, + mm_1, + rms_norm_weights, + scale, + ] + + def register(self, pm_pass: PatternMatcherPass): + + def pattern( + result: torch.Tensor, + residual: torch.Tensor, + mm_1: torch.Tensor, + rms_norm_weights: torch.Tensor, + scale: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + all_reduce = self._all_reduce(mm_1) + static_fp8, rmsnorm_residual_out = self._functional_fused_add_rmsnorm_then_quant( # noqa: E501 + result, all_reduce, residual, rms_norm_weights, scale) + return static_fp8[1], rmsnorm_residual_out def replacement( + result: torch.Tensor, + residual: torch.Tensor, + mm_1: torch.Tensor, + rms_norm_weights: torch.Tensor, + scale: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + reduce_scatter = self._reduce_scatter(mm_1) + quant_result_buf = torch.empty_like(reduce_scatter, + dtype=result.dtype) + static_fp8, rmsnorm_residual_out = self._functional_fused_add_rmsnorm_then_quant( # noqa: E501 + quant_result_buf, reduce_scatter, residual, rms_norm_weights, + scale) + all_gather = self._all_gather(static_fp8[1]) + return all_gather, rmsnorm_residual_out + + pm.register_replacement(pattern, replacement, self.get_inputs(), + pm.fwd_only, pm_pass) + + +class LastAllReduceRMSNormStaticFP8Pattern(_RMSNormQuantOpHelper): + + def __init__(self, epsilon: float, dtype: torch.dtype, device: str, + op: torch._ops.OpOverload): + super().__init__(epsilon, dtype, device, quant_op=op) + + def get_inputs(self): + mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype) + + residual = torch.empty([4, 4], device=self.device, dtype=self.dtype) + rms_norm_weights = torch.empty([4, 4], + device=self.device, + dtype=self.dtype) + result = torch.empty([4, 4], device=self.device, dtype=FP8_DTYPE) + scale = torch.empty([1, 1], device=self.device, dtype=torch.float32) + + return [ + result, + residual, + mm_1, + rms_norm_weights, + scale, + ] + + def register(self, pm_pass: PatternMatcherPass): + + def pattern( + result: torch.Tensor, residual: torch.Tensor, mm_1: torch.Tensor, rms_norm_weights: torch.Tensor, + scale: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: - tp = get_tp_group() - tp_size = get_tensor_model_parallel_world_size() - reduce_scatter = torch.ops.vllm.reduce_scatter.default( - mm_1, dim=0, world_size=tp_size, group_name=tp.unique_name) - - # TODO is it possible to extract epsilon from somewhere - rmsnorm = torch.ops.higher_order.auto_functionalized( - torch.ops._C.fused_add_rms_norm.default, - input=reduce_scatter, - residual=residual, - weight=rms_norm_weights, - epsilon=self.epsilon, - ) - - normalized = torch.ops.vllm.all_gather.default( - rmsnorm[1], - dim=0, - world_size=tp_size, - group_name=tp.unique_name) + all_reduce = self._all_reduce(mm_1) + static_fp8, _ = self._functional_fused_add_rmsnorm_then_quant( + result, all_reduce, residual, rms_norm_weights, scale) + return static_fp8[1] + def replacement( + result: torch.Tensor, + residual: torch.Tensor, + mm_1: torch.Tensor, + rms_norm_weights: torch.Tensor, + scale: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + reduce_scatter = self._reduce_scatter(mm_1) + quant_result_buf = torch.empty_like(reduce_scatter, + dtype=result.dtype) + static_fp8, _ = self._functional_fused_add_rmsnorm_then_quant( + quant_result_buf, reduce_scatter, residual, rms_norm_weights, + scale) + normalized = self._all_gather(static_fp8[1]) return normalized pm.register_replacement(pattern, replacement, self.get_inputs(), @@ -241,7 +686,31 @@ def __init__(self, config: VllmConfig): self.patterns: PatternMatcherPass = PatternMatcherPass( pass_name="sequence_parallelism_pass") + for epsilon in [1e-5, 1e-6]: + # RMSNorm + Static FP8 quantization patterns + fp8_quant_op = torch.ops._C.static_scaled_fp8_quant.default + EmbeddingAllReduceRMSNormStaticFP8Pattern( + epsilon, self.model_dtype, self.device, + fp8_quant_op).register(self.patterns) + MiddleAllReduceRMSNormStaticFP8Pattern( + epsilon, self.model_dtype, self.device, + fp8_quant_op).register(self.patterns) + LastAllReduceRMSNormStaticFP8Pattern( + epsilon, self.model_dtype, self.device, + fp8_quant_op).register(self.patterns) + + # Fused RMSNorm + Static FP8 patterns + EmbeddingAllReduceFusedRMSNormStaticFP8Pattern( + epsilon, self.model_dtype, self.device).register(self.patterns) + + MiddleAllReduceFusedRMSNormStaticFP8Pattern( + epsilon, self.model_dtype, self.device).register(self.patterns) + + LastAllReduceFusedRMSNormStaticFP8Pattern( + epsilon, self.model_dtype, self.device).register(self.patterns) + + # Normal RMSNorm patterns EmbeddingAllReduceRMSNormPattern( epsilon, self.model_dtype, self.device).register(self.patterns) @@ -250,6 +719,7 @@ def __init__(self, config: VllmConfig): LastAllReduceRMSNormPattern(epsilon, self.model_dtype, self.device).register(self.patterns) + # WARNING: This is a hack to clear the pattern matcher cache # and allow multiple values of epsilon. torch._inductor.pattern_matcher._seen_patterns.clear() diff --git a/vllm/config.py b/vllm/config.py index 6cec97a5f11b..04a500a9a86a 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -4320,8 +4320,8 @@ def __post_init__(self): self.compilation_config.custom_ops = ["none"] self.compilation_config.use_cudagraph = True self.compilation_config.cudagraph_num_of_warmups = 1 - self.compilation_config.pass_config.enable_fusion = False - self.compilation_config.pass_config.enable_noop = False + # self.compilation_config.pass_config.enable_fusion = False + # self.compilation_config.pass_config.enable_noop = False self.compilation_config.level = CompilationLevel.PIECEWISE self.compilation_config.set_splitting_ops_for_v1()