diff --git a/tests/compile/test_fusion2.py b/tests/compile/test_fusion2.py
new file mode 100644
index 000000000000..de8f0b665414
--- /dev/null
+++ b/tests/compile/test_fusion2.py
@@ -0,0 +1,121 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import pytest
+import torch
+
+import vllm.envs as envs
+import vllm.plugins
+from vllm.compilation.fusion import (FUSED_OPS, QUANT_OPS, FusedRMSQuantKey,
+                                     FusionPass, GroupShape, QuantKey)
+from vllm.compilation.noop_elimination import NoOpEliminationPass
+from vllm.config import (CompilationConfig, CompilationLevel, PassConfig,
+                         VllmConfig)
+from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
+    CUTLASS_FP8_SUPPORTED, Fp8LinearOp, maybe_create_device_identity)
+from vllm.platforms import current_platform
+
+from .backend import TestBackend
+
+FP8_DTYPE = current_platform.fp8_dtype()
+
+
+class TestModel(torch.nn.Module):
+
+    def __init__(self, hidden_size: int, eps: float, static: bool,
+                 cutlass_fp8_enabled: bool, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.cutlass_fp8_enabled = cutlass_fp8_enabled
+        self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)]
+        self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(2)]
+        group_shape = GroupShape.PER_TENSOR if static else GroupShape.PER_TOKEN
+        self.key = QuantKey(dtype=FP8_DTYPE,
+                            static=static,
+                            group_shape=group_shape,
+                            symmetric=True)
+        if static:
+            self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(2)]
+        else:
+            self.scale = [None for _ in range(2)]
+        self.w = [
+            torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
+            for _ in range(2)
+        ]
+        self.fp8_linear = Fp8LinearOp(
+            cutlass_fp8_supported=cutlass_fp8_enabled,
+            act_quant_static=static,
+            act_quant_group_shape=group_shape,
+        )
+
+    def forward(self, x):
+        resid = torch.sqrt(x)
+        y = self.norm[0](x)
+
+        return self.fp8_linear.apply(y,
+                                     self.w[0],
+                                     self.wscale[0],
+                                     input_scale=self.scale[0])
+
+    def ops_in_model_before(self):
+        return [QUANT_OPS[self.key]]
+
+    def ops_in_model_after(self):
+        return [
+            FUSED_OPS[FusedRMSQuantKey(self.key, False)],
+        ]
+
+
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("hidden_size", [64, 3392, 4096])
+@pytest.mark.parametrize("num_tokens", [7, 256, 533, 2048, 2049])
+@pytest.mark.parametrize("eps", [1e-5, 1e-6])
+@pytest.mark.parametrize("static", [True, False])
+@pytest.mark.parametrize("cutlass_fp8_enabled",
+                         [True, False] if CUTLASS_FP8_SUPPORTED else [False])
+@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"],
+                    reason="Only test on CUDA and ROCm")
+def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,
+                              cutlass_fp8_enabled):
+    torch.set_default_device("cuda")
+    torch.set_default_dtype(dtype)
+    torch.manual_seed(1)
+    maybe_create_device_identity()  # needed for certain non-cutlass fp8 paths
+
+    vllm_config = VllmConfig(compilation_config=CompilationConfig(
+        level=CompilationLevel.PIECEWISE,
+        custom_ops=["+rms_norm"],
+        pass_config=PassConfig(enable_fusion=True, enable_noop=True),
+    ))
+    with vllm.config.set_current_vllm_config(vllm_config):
+        # Reshape pass is needed for the fusion pass to work
+        noop_pass = NoOpEliminationPass(vllm_config)
+        fusion_pass = FusionPass.instance(vllm_config)
+
+        backend = TestBackend(noop_pass, fusion_pass)
+        model = TestModel(hidden_size, eps, static, cutlass_fp8_enabled)
+
+        # First dimension dynamic
+        x = torch.rand(num_tokens, hidden_size)
+        torch._dynamo.mark_dynamic(x, 0)
+
+        result = model(x)
+
+        model2 = torch.compile(model, backend=backend)
+        result2 = model2(x)
+
+        # Higher tol for dynamic, even higher for bfloat16
+        if static:
+            ATOL, RTOL = (1e-3, 1e-3)
+        elif dtype == torch.float16:
+            ATOL, RTOL = (2e-3, 2e-3)
+        else:
+            ATOL, RTOL = (1e-2, 1e-2)
+
+        torch.testing.assert_close(result, result2, atol=ATOL, rtol=RTOL)
+
+        # In pre-nodes, fp8 quant should be there and fused kernels should not
+        # backend.check_before_ops(model.ops_in_model_before())
+
+        # In post-nodes, fused kernels should be there and fp8 quant should not
+        backend.check_after_ops(model.ops_in_model_after())
diff --git a/tests/compile/test_fusion_all_reduce.py b/tests/compile/test_fusion_all_reduce.py
index 4c3cf6c2a10c..aea3db60767b 100644
--- a/tests/compile/test_fusion_all_reduce.py
+++ b/tests/compile/test_fusion_all_reduce.py
@@ -8,9 +8,11 @@
 import vllm.envs as envs
 from vllm.compilation.collective_fusion import AllReduceFusionPass
 from vllm.compilation.fix_functionalization import FixFunctionalizationPass
+from vllm.compilation.fx_utils import find_op_nodes
 from vllm.compilation.noop_elimination import NoOpEliminationPass
 from vllm.config import (CompilationConfig, CompilationLevel, DeviceConfig,
-                         ModelConfig, PassConfig, VllmConfig)
+                         ModelConfig, PassConfig, VllmConfig,
+                         set_current_vllm_config)
 from vllm.distributed import tensor_model_parallel_all_reduce
 from vllm.distributed.parallel_state import (init_distributed_environment,
                                              initialize_model_parallel)
@@ -76,17 +78,15 @@ def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
         self.quant_fp8 = QuantFP8(static=True,
                                   group_shape=GroupShape.PER_TENSOR)
         self.scale = torch.rand(1, dtype=torch.float32)
-        self.output = torch.empty((token_num, hidden_size),
-                                  dtype=torch.float32)
+        # self.output = torch.empty((token_num, hidden_size),
+        #                           dtype=torch.float32)
 
     def forward(self, hidden_states, residual):
         view = hidden_states.reshape(-1, self.hidden_size)
         all_reduce = tensor_model_parallel_all_reduce(view)
         norm_output, residual_output = self.norm(all_reduce, residual)
-        torch.ops._C.static_scaled_fp8_quant(self.output,
-                                             norm_output.contiguous(),
-                                             self.scale)
-        return self.output, residual_output
+        quant_out, _ = self.quant_fp8(norm_output, scale=self.scale)
+        return quant_out, residual_output
 
     def ops_in_model_after(self):
         return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default]
@@ -94,7 +94,7 @@ def ops_in_model_after(self):
     def ops_in_model_before(self):
         return [
             torch.ops.vllm.all_reduce.default,
-            torch.ops._C.static_scaled_fp8_quant.default
+            # torch.ops._C.static_scaled_fp8_quant.default
         ]
 
 
@@ -198,8 +198,7 @@ def all_reduce_fusion_pass_on_test_model(local_rank: int, world_size: int,
     initialize_model_parallel(tensor_model_parallel_size=world_size)
 
     vllm_config = VllmConfig(compilation_config=CompilationConfig(
-        level=CompilationLevel.PIECEWISE,
-        custom_ops=["+rms_norm", "+quant_fp8"]))
+        level=CompilationLevel.PIECEWISE, custom_ops=["+rms_norm"]))
     vllm_config.compilation_config.pass_config = PassConfig(
         enable_fi_allreduce_fusion=True, enable_noop=True)
     vllm_config.device_config = DeviceConfig(device=torch.device("cuda"))
@@ -211,22 +210,32 @@ def all_reduce_fusion_pass_on_test_model(local_rank: int, world_size: int,
                              trust_remote_code=True,
                              dtype=dtype,
                              seed=42)
-
-    all_reduce_fusion_pass = AllReduceFusionPass(vllm_config)
-    noop_pass = NoOpEliminationPass(vllm_config)
-    func_pass = FixFunctionalizationPass(vllm_config)
-
-    backend = TestBackend(all_reduce_fusion_pass, noop_pass, func_pass)
-
-    token_num = batch_size * seq_len
-    model = test_model_cls(hidden_size, token_num)
-
-    hidden_states = torch.randn((token_num, hidden_size), requires_grad=False)
-    residual = torch.randn((token_num, hidden_size), requires_grad=False)
-
-    compiled_model = torch.compile(model, backend=backend)
-    compiled_model(hidden_states, residual)
-
-    backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False)
-    backend.check_after_ops(model.ops_in_model_after())
-    del all_reduce_fusion_pass
+    with set_current_vllm_config(vllm_config):
+        all_reduce_fusion_pass = AllReduceFusionPass(vllm_config)
+        noop_pass = NoOpEliminationPass(vllm_config)
+        func_pass = FixFunctionalizationPass(vllm_config)
+
+        backend = TestBackend(all_reduce_fusion_pass, noop_pass, func_pass)
+
+        token_num = batch_size * seq_len
+        model = test_model_cls(hidden_size, token_num)
+
+        hidden_states = torch.randn((token_num, hidden_size),
+                                    requires_grad=False)
+        residual = torch.randn((token_num, hidden_size), requires_grad=False)
+
+        compiled_model = torch.compile(model, backend=backend)
+        compiled_model(hidden_states, residual)
+
+        backend.check_before_ops(model.ops_in_model_before(),
+                                 fully_replaced=False)
+        backend.check_after_ops(model.ops_in_model_after())
+        print(backend.graph_pre_pass)
+        print(backend.graph_post_pass)
+        for node in find_op_nodes(
+                torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default,
+                backend.graph_post_pass):
+            print(f"{node.args=}")
+            print(f"{node.kwargs=}")
+
+        del all_reduce_fusion_pass
diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py
index 6ae50245ed3a..9b4b3fa0944e 100644
--- a/vllm/compilation/collective_fusion.py
+++ b/vllm/compilation/collective_fusion.py
@@ -18,6 +18,8 @@
 from vllm.platforms import current_platform
 from vllm.utils import direct_register_custom_op
 
+from ..model_executor.layers.quantization.input_quant_fp8 import QuantFP8
+from ..model_executor.layers.quantization.utils.quant_utils import GroupShape
 from .vllm_inductor_pass import VllmInductorPass
 
 FP8_DTYPE = current_platform.fp8_dtype()
@@ -715,6 +717,10 @@ def __init__(self, epsilon: float, dtype: torch.dtype, device: str,
         self.epsilon = epsilon
         self.allreduce_params = allreduce_params
         self.quant_dtype = torch.float8_e4m3fn
+        self.quant_fp8 = QuantFP8(static=True,
+                                  group_shape=GroupShape.PER_TENSOR)
+        # TODO HACK
+        self.quant_fp8._forward_method = self.quant_fp8.forward_native
 
     def register(self, pm_pass: PatternMatcherPass):
 
@@ -725,17 +731,23 @@ def get_inputs():
            rmsnorm_result = torch.empty([1, 8, 4],
                                         device=self.device,
                                         dtype=self.dtype)
-            quant_result = torch.empty([1, 8, 4],
-                                       device=self.device,
-                                       dtype=self.quant_dtype)
+            # quant_result = torch.empty([1, 8, 4],
+            #                            device=self.device,
+            #                            dtype=self.quant_dtype)
             weight = torch.empty([4], device=self.device, dtype=self.dtype)
             scale = torch.tensor(1.0, device=self.device, dtype=torch.float32)
-            return [input, rmsnorm_result, quant_result, weight, scale]
+            return [
+                input,
+                rmsnorm_result,
+                # quant_result,
+                weight,
+                scale
+            ]
 
         def pattern(
             input: torch.Tensor,
             rmsnorm_result: torch.Tensor,
-            quant_result: torch.Tensor,
+            # quant_result: torch.Tensor,
             weight: torch.Tensor,
             scale: torch.Tensor,
         ):
@@ -746,18 +758,19 @@ def pattern(
                                                     weight=weight,
                                                     epsilon=self.epsilon)
 
-            quant_out_tuple = auto_functionalized(STATIC_FP8_QUANT_OP,
-                                                  result=quant_result,
-                                                  input=rmsnorm_out_tuple[1],
-                                                  scale=scale)
+            quant_out, _ = self.quant_fp8(rmsnorm_out_tuple[1], scale=scale)
 
             # quant_out, allreduce_output
-            return quant_out_tuple[1], all_reduce
+            return quant_out, all_reduce
 
-        def replacement(input: torch.Tensor, result_rms: torch.Tensor,
-                        quant_result: torch.Tensor, weight: torch.Tensor,
-                        scale: torch.Tensor):
+        def replacement(
+                input: torch.Tensor,
+                result_rms: torch.Tensor,
+                # quant_result: torch.Tensor,
+                weight: torch.Tensor,
+                scale: torch.Tensor):
             residual = torch.zeros_like(input)
+            quant_result = torch.empty_like(input, dtype=FP8_DTYPE)
             allreduce = auto_functionalized(
                 flashinfer_trtllm_fused_allreduce_norm,
                 allreduce_in=input,
@@ -794,6 +807,10 @@ def __init__(self, epsilon: float, dtype: torch.dtype, device: str,
         self.epsilon = epsilon
         self.allreduce_params = allreduce_params
         self.quant_dtype = torch.float8_e4m3fn
+        self.quant_fp8 = QuantFP8(static=True,
+                                  group_shape=GroupShape.PER_TENSOR)
+        # TODO HACK
+        self.quant_fp8._forward_method = self.quant_fp8.forward_native
 
     def register(self, pm_pass: PatternMatcherPass):
 
@@ -804,15 +821,15 @@ def get_inputs():
                                    device=self.device,
                                    dtype=self.dtype)
             weight = torch.empty([4, 4], device=self.device, dtype=self.dtype)
-            quant_result = torch.empty([4, 4],
-                                       device=self.device,
-                                       dtype=self.quant_dtype)
+            # quant_result = torch.empty([4, 4],
+            #                            device=self.device,
+            #                            dtype=self.quant_dtype)
             scale = torch.empty([1, 1],
                                 device=self.device,
                                 dtype=torch.float32)
 
             return [
-                quant_result,
+                # quant_result,
                 residual,
                 input,
                 weight,
@@ -820,7 +837,7 @@ def pattern(
-            quant_result: torch.Tensor,
+            # quant_result: torch.Tensor,
             residual: torch.Tensor,
             input: torch.Tensor,
             weight: torch.Tensor,
@@ -835,18 +852,21 @@ def pattern(
                 residual=residual,
                 weight=weight,
                 epsilon=self.epsilon)
-            quant_out_tuple = auto_functionalized(
-                STATIC_FP8_QUANT_OP,
-                result=quant_result,
-                input=fused_add_rmsnorm_out_tuple[1],
-                scale=scale)
+
+            quant_out, _ = self.quant_fp8(fused_add_rmsnorm_out_tuple[1],
+                                          scale=scale)
 
             # quant_out, allreduce_output
-            return quant_out_tuple[1], fused_add_rmsnorm_out_tuple[2]
+            return quant_out, fused_add_rmsnorm_out_tuple[2]
 
-        def replacement(quant_result: torch.Tensor, residual: torch.Tensor,
-                        input: torch.Tensor, weight: torch.Tensor,
-                        scale: torch.Tensor):
+        def replacement(
+                # quant_result: torch.Tensor,
+                residual: torch.Tensor,
+                input: torch.Tensor,
+                weight: torch.Tensor,
+                scale: torch.Tensor,
+        ):
+            quant_result = torch.empty_like(input, dtype=FP8_DTYPE)
             allreduce = auto_functionalized(
                 flashinfer_trtllm_fused_allreduce_norm,
                 allreduce_in=input,
@@ -1120,36 +1140,39 @@ def __init__(self, config: VllmConfig):
                 self.device,
                 self.allreduce_params,
             ).register(self.patterns)
-            if current_platform.has_device_capability(100):
-                AllReduceFusedRMSNormStaticQuantNVFP4Pattern(
-                    epsilon,
-                    self.model_dtype,
-                    self.device,
-                    self.allreduce_params,
-                ).register(self.patterns)
-                AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(
-                    epsilon,
-                    self.model_dtype,
-                    self.device,
-                    self.allreduce_params,
-                ).register(self.patterns)
-            AllReduceRMSNormPattern(
-                epsilon,
-                self.model_dtype,
-                self.device,
-                self.allreduce_params,
-            ).register(self.patterns)
-            AllReduceFusedAddRMSNormPattern(
-                epsilon,
-                self.model_dtype,
-                self.device,
-                self.allreduce_params,
-            ).register(self.patterns)
+            # if current_platform.has_device_capability(100):
+            #     AllReduceFusedRMSNormStaticQuantNVFP4Pattern(
+            #         epsilon,
+            #         self.model_dtype,
+            #         self.device,
+            #         self.allreduce_params,
+            #     ).register(self.patterns)
+            #     AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(
+            #         epsilon,
+            #         self.model_dtype,
+            #         self.device,
+            #         self.allreduce_params,
+            #     ).register(self.patterns)
+            # AllReduceRMSNormPattern(
+            #     epsilon,
+            #     self.model_dtype,
+            #     self.device,
+            #     self.allreduce_params,
+            # ).register(self.patterns)
+            # AllReduceFusedAddRMSNormPattern(
+            #     epsilon,
+            #     self.model_dtype,
+            #     self.device,
+            #     self.allreduce_params,
+            # ).register(self.patterns)
 
             # WARNING: This is a hack to clear the pattern matcher cache
             # and allow multiple values of epsilon.
             torch._inductor.pattern_matcher._seen_patterns.clear()
 
+        if path := config.compilation_config.debug_dump_path:
+            with open(f"{path}/patterns.txt", 'w') as f:
+                print(self.patterns.patterns, file=f)
         self.disabled = False
 
     def __call__(self, graph: fx.Graph):
diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py
index 3dec939c2835..d4bc612c8332 100644
--- a/vllm/compilation/fusion.py
+++ b/vllm/compilation/fusion.py
@@ -11,6 +11,7 @@
 
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
+from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     GroupShape)
 from vllm.platforms import current_platform
@@ -237,6 +238,63 @@ def replacement(result: torch.Tensor, result_rms: torch.Tensor,
                                 pm_pass)
 
 
+class RMSNormStaticQuantPattern2(RMSNormQuantPattern):
+
+    def __init__(self,
+                 epsilon: float,
+                 quant_dtype: torch.dtype,
+                 symmetric=True):
+        fused_key = FusedRMSQuantKey(fused_add=False,
+                                     quant=QuantKey(
+                                         dtype=quant_dtype,
+                                         static=True,
+                                         group_shape=GroupShape.PER_TENSOR,
+                                         symmetric=symmetric))
+        super().__init__(epsilon, fused_key)
+
+        self.quant_fp8 = QuantFP8(static=True,
+                                  group_shape=GroupShape.PER_TENSOR)
+
+    def register(self, pm_pass: PatternMatcherPass):
+        # Cannot use methods, as the self argument affects tracing
+        def pattern(result_rms: torch.Tensor, input: torch.Tensor,
+                    weight: torch.Tensor, scale: torch.Tensor):
+            at1 = auto_functionalized(RMS_OP,
+                                      result=result_rms,
+                                      input=input,
+                                      weight=weight,
+                                      epsilon=self.epsilon)
+            result2, _ = self.quant_fp8(at1[1], scale=scale)
+            # result
+            return result2
+
+        def replacement(result_rms: torch.Tensor, input: torch.Tensor,
+                        weight: torch.Tensor, scale: torch.Tensor):
+            result = torch.empty(input.size(),
+                                 dtype=self.quant_dtype,
+                                 device=input.device)
+            at = auto_functionalized(self.FUSED_OP,
+                                     result=result,
+                                     input=input,
+                                     weight=weight,
+                                     scale=scale,
+                                     epsilon=self.epsilon)
+
+            # result
+            return at[1]
+
+        inputs = [
+            # torch.empty(5, 4, device="cuda", dtype=self.quant_dtype),  # result
+            empty_bf16(5, 4),  # result_rms
+            empty_bf16(5, 4),  # input
+            empty_bf16(1, 5),  # weight
+            empty_fp32(1, 1)  # scale
+        ]
+
+        pm.register_replacement(pattern, replacement, inputs, pm.fwd_only,
+                                pm_pass)
+
+
 class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
 
     def __init__(self,
@@ -569,6 +627,9 @@ def __init__(self, config: VllmConfig):
 
         # Fuse rms_norm + static fp8 quant
         RMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(self.patterns)
+        # Fuse rms_norm + static fp8 quant
+        RMSNormStaticQuantPattern2(epsilon,
+                                   FP8_DTYPE).register(self.patterns)
 
         # Matches for patterns below have 2 or more outputs,
         # so we need to process them manually (see process_matches)
diff --git a/vllm/compilation/pass_manager.py b/vllm/compilation/pass_manager.py
index e07e52be9fdf..55ae09b43776 100644
--- a/vllm/compilation/pass_manager.py
+++ b/vllm/compilation/pass_manager.py
@@ -1,11 +1,15 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+from contextlib import ExitStack
+
 from torch import fx as fx
 
+from vllm import envs
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
+from vllm.utils import set_env_var
 
 if current_platform.is_cuda_alike():
     from .fusion import FusionPass
@@ -43,13 +47,20 @@ def __init__(self):
         self.passes: list[VllmInductorPass] = []
 
     def __call__(self, graph: fx.Graph):
-        shape = get_pass_context().runtime_shape
-        for pass_ in self.passes:
-            if pass_.is_applicable_for_shape(shape):
-                pass_(graph)
-
-        # always run fix_functionalization last
-        self.fix_functionalization(graph)
+        with ExitStack() as stack:
+            if envs.VLLM_PATTERN_MATCH_DEBUG is not None:
+                # and get_tensor_model_parallel_rank() == 0:
+                stack.enter_context(
+                    set_env_var('TORCHINDUCTOR_PATTERN_MATCH_DEBUG',
+                                envs.VLLM_PATTERN_MATCH_DEBUG))
+
+            shape = get_pass_context().runtime_shape
+            for pass_ in self.passes:
+                if pass_.is_applicable_for_shape(shape):
+                    pass_(graph)
+
+            # always run fix_functionalization last
+            self.fix_functionalization(graph)
 
     def configure(self, config: VllmConfig):
         self.pass_config = config.compilation_config.pass_config
diff --git a/vllm/envs.py b/vllm/envs.py
index 145ec3495a0c..6c7e22dc40b1 100755
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -160,6 +160,9 @@
     VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8: bool = False
     VLLM_USE_FLASHINFER_MOE_MXFP4_BF16: bool = False
     VLLM_TUNED_CONFIG_FOLDER: Optional[str] = None
+    VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE: bool = True
+    VLLM_USE_STANDALONE_COMPILE: bool = True
+    VLLM_PATTERN_MATCH_DEBUG: Optional[str] = None
 
 
 def get_default_cache_root():
@@ -363,6 +366,10 @@ def get_vllm_port() -> Optional[int]:
     "VLLM_USE_STANDALONE_COMPILE":
     lambda: os.environ.get("VLLM_USE_STANDALONE_COMPILE", "1") == "1",
 
+    # Debug pattern matching inside custom passes
+    "VLLM_PATTERN_MATCH_DEBUG":
+    lambda: os.environ.get("VLLM_PATTERN_MATCH_DEBUG"),
+
     # local rank of the process in the distributed setting, used to determine
     # the GPU device id
     "LOCAL_RANK":
diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py
index 095829db8394..5e4bdc801014 100644
--- a/vllm/utils/__init__.py
+++ b/vllm/utils/__init__.py
@@ -3346,3 +3346,16 @@ def decorate_logs(process_name: Optional[str] = None) -> None:
     pid = os.getpid()
     _add_prefix(sys.stdout, process_name, pid)
     _add_prefix(sys.stderr, process_name, pid)
+
+
+@contextlib.contextmanager
+def set_env_var(key, value):
+    old = os.environ.get(key)
+    os.environ[key] = value
+    try:
+        yield
+    finally:
+        if old is None:
+            del os.environ[key]
+        else:
+            os.environ[key] = old