TEMP changes for matching #22856
Closed: ProExpertProg wants to merge 4 commits into vllm-project:main from neuralmagic:luka/torch-custom-op-pattern
Commits (4):
40c4388 TEMP changes for matching (ProExpertProg)
bc5dfaf Add set_env_var helper and add pattern matcher debug utility (ProExpertProg)
5eeb376 TEMP collective fusion hack to enable custom op, matching rms_norm an… (ProExpertProg)
ce2d3be Fix default (ProExpertProg)

The diff below shows the changes from 1 commit.
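Commit bc5dfaf mentions a `set_env_var` helper; its implementation is not part of the single commit shown below. Such helpers are commonly written as context managers that set an environment variable and restore the previous value on exit. A minimal sketch under that assumption (the name comes from the commit message, but the signature and behavior are assumptions, not the PR's actual code):

```python
import os
from contextlib import contextmanager


@contextmanager
def set_env_var(key: str, value: str):
    """Temporarily set an environment variable, restoring the old value on exit."""
    old = os.environ.get(key)
    os.environ[key] = value
    try:
        yield
    finally:
        if old is None:
            os.environ.pop(key, None)
        else:
            os.environ[key] = old
```

Usage would then look like `with set_env_var("VLLM_LOGGING_LEVEL", "DEBUG"): ...`.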
New file (121 added lines), a test of RMSNorm + FP8 quant fusion:

```python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
import torch

import vllm.envs as envs
import vllm.plugins
from vllm.compilation.fusion import (FUSED_OPS, QUANT_OPS, FusedRMSQuantKey,
                                     FusionPass, GroupShape, QuantKey)
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.config import (CompilationConfig, CompilationLevel, PassConfig,
                         VllmConfig)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
    CUTLASS_FP8_SUPPORTED, Fp8LinearOp, maybe_create_device_identity)
from vllm.platforms import current_platform

from .backend import TestBackend

FP8_DTYPE = current_platform.fp8_dtype()


class TestModel(torch.nn.Module):

    def __init__(self, hidden_size: int, eps: float, static: bool,
                 cutlass_fp8_enabled: bool, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.cutlass_fp8_enabled = cutlass_fp8_enabled
        self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)]
        self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(2)]
        group_shape = GroupShape.PER_TENSOR if static else GroupShape.PER_TOKEN
        self.key = QuantKey(dtype=FP8_DTYPE,
                            static=static,
                            group_shape=group_shape,
                            symmetric=True)
        if static:
            self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(2)]
        else:
            self.scale = [None for _ in range(2)]
        self.w = [
            torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
            for _ in range(2)
        ]
        self.fp8_linear = Fp8LinearOp(
            cutlass_fp8_supported=cutlass_fp8_enabled,
            act_quant_static=static,
            act_quant_group_shape=group_shape,
        )

    def forward(self, x):
        resid = torch.sqrt(x)
        y = self.norm[0](x)

        return self.fp8_linear.apply(y,
                                     self.w[0],
                                     self.wscale[0],
                                     input_scale=self.scale[0])

    def ops_in_model_before(self):
        return [QUANT_OPS[self.key]]

    def ops_in_model_after(self):
        return [
            FUSED_OPS[FusedRMSQuantKey(self.key, False)],
        ]


@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("hidden_size", [64, 3392, 4096])
@pytest.mark.parametrize("num_tokens", [7, 256, 533, 2048, 2049])
@pytest.mark.parametrize("eps", [1e-5, 1e-6])
@pytest.mark.parametrize("static", [True, False])
@pytest.mark.parametrize("cutlass_fp8_enabled",
                         [True, False] if CUTLASS_FP8_SUPPORTED else [False])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"],
                    reason="Only test on CUDA and ROCm")
def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,
                              cutlass_fp8_enabled):
    torch.set_default_device("cuda")
    torch.set_default_dtype(dtype)
    torch.manual_seed(1)
    maybe_create_device_identity()  # needed for certain non-cutlass fp8 paths

    vllm_config = VllmConfig(compilation_config=CompilationConfig(
        level=CompilationLevel.PIECEWISE,
        custom_ops=["+rms_norm"],
        pass_config=PassConfig(enable_fusion=True, enable_noop=True),
    ))
    with vllm.config.set_current_vllm_config(vllm_config):
        # Reshape pass is needed for the fusion pass to work
        noop_pass = NoOpEliminationPass(vllm_config)
        fusion_pass = FusionPass.instance(vllm_config)

        backend = TestBackend(noop_pass, fusion_pass)
        model = TestModel(hidden_size, eps, static, cutlass_fp8_enabled)

        # First dimension dynamic
        x = torch.rand(num_tokens, hidden_size)
        torch._dynamo.mark_dynamic(x, 0)

        result = model(x)

        model2 = torch.compile(model, backend=backend)
        result2 = model2(x)

        # Higher tol for dynamic, even higher for bfloat16
        if static:
            ATOL, RTOL = (1e-3, 1e-3)
        elif dtype == torch.float16:
            ATOL, RTOL = (2e-3, 2e-3)
        else:
            ATOL, RTOL = (1e-2, 1e-2)

        torch.testing.assert_close(result, result2, atol=ATOL, rtol=RTOL)

        # In pre-nodes, fp8 quant should be there and fused kernels should not
        # backend.check_before_ops(model.ops_in_model_before())

        # In post-nodes, fused kernels should be there and fp8 quant should not
        backend.check_after_ops(model.ops_in_model_after())
```
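Both the new test above and the all-reduce fusion test below drive compilation through `TestBackend` (imported from the local `.backend` test helper), whose implementation is not part of this diff. A rough, hypothetical sketch of the idea, assuming a `torch.compile` backend that snapshots the FX graph before and after running the supplied passes, and that each pass is a callable that mutates an `fx.Graph` in place (class and attribute names below are illustrative, not the real helper's code):

```python
import copy
from typing import Callable, Optional

import torch.fx as fx


class PassRecordingBackend:
    """Illustrative stand-in for the TestBackend helper used in these tests."""

    def __init__(self, *passes: Callable[[fx.Graph], None]):
        self.passes = passes
        self.graph_pre_pass: Optional[fx.Graph] = None
        self.graph_post_pass: Optional[fx.Graph] = None

    def __call__(self, gm: fx.GraphModule, example_inputs):
        # Snapshot the captured graph before any custom passes run.
        self.graph_pre_pass = copy.deepcopy(gm.graph)
        for p in self.passes:
            p(gm.graph)
        gm.recompile()
        # Snapshot the graph after the passes, so a test can assert which
        # ops were eliminated or fused.
        self.graph_post_pass = copy.deepcopy(gm.graph)
        # Run the transformed graph eagerly.
        return gm.forward
```

With a backend like this, `torch.compile(model, backend=backend)` hands the traced graph to the passes, and assertions such as `check_before_ops`/`check_after_ops` in the real helper presumably inspect the two snapshots.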
Changes to a second, existing test file (the all-reduce fusion pass test):

```diff
@@ -8,9 +8,11 @@
 import vllm.envs as envs
 from vllm.compilation.collective_fusion import AllReduceFusionPass
 from vllm.compilation.fix_functionalization import FixFunctionalizationPass
+from vllm.compilation.fx_utils import find_op_nodes
 from vllm.compilation.noop_elimination import NoOpEliminationPass
 from vllm.config import (CompilationConfig, CompilationLevel, DeviceConfig,
-                         ModelConfig, PassConfig, VllmConfig)
+                         ModelConfig, PassConfig, VllmConfig,
+                         set_current_vllm_config)
 from vllm.distributed import tensor_model_parallel_all_reduce
 from vllm.distributed.parallel_state import (init_distributed_environment,
                                              initialize_model_parallel)
@@ -76,25 +78,23 @@ def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
         self.quant_fp8 = QuantFP8(static=True,
                                   group_shape=GroupShape.PER_TENSOR)
         self.scale = torch.rand(1, dtype=torch.float32)
-        self.output = torch.empty((token_num, hidden_size),
-                                  dtype=torch.float32)
+        # self.output = torch.empty((token_num, hidden_size),
+        #                           dtype=torch.float32)
 
     def forward(self, hidden_states, residual):
         view = hidden_states.reshape(-1, self.hidden_size)
         all_reduce = tensor_model_parallel_all_reduce(view)
         norm_output, residual_output = self.norm(all_reduce, residual)
-        torch.ops._C.static_scaled_fp8_quant(self.output,
-                                             norm_output.contiguous(),
-                                             self.scale)
-        return self.output, residual_output
+        quant_out, _ = self.quant_fp8(norm_output, scale=self.scale)
+        return quant_out, residual_output
 
     def ops_in_model_after(self):
         return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default]
 
     def ops_in_model_before(self):
         return [
             torch.ops.vllm.all_reduce.default,
-            torch.ops._C.static_scaled_fp8_quant.default
+            # torch.ops._C.static_scaled_fp8_quant.default
         ]
@@ -198,8 +198,7 @@ def all_reduce_fusion_pass_on_test_model(local_rank: int, world_size: int,
     initialize_model_parallel(tensor_model_parallel_size=world_size)
 
     vllm_config = VllmConfig(compilation_config=CompilationConfig(
-        level=CompilationLevel.PIECEWISE,
-        custom_ops=["+rms_norm", "+quant_fp8"]))
+        level=CompilationLevel.PIECEWISE, custom_ops=["+rms_norm"]))
     vllm_config.compilation_config.pass_config = PassConfig(
         enable_fi_allreduce_fusion=True, enable_noop=True)
     vllm_config.device_config = DeviceConfig(device=torch.device("cuda"))
@@ -211,22 +210,32 @@ def all_reduce_fusion_pass_on_test_model(local_rank: int, world_size: int,
         trust_remote_code=True,
         dtype=dtype,
         seed=42)
 
-    all_reduce_fusion_pass = AllReduceFusionPass(vllm_config)
-    noop_pass = NoOpEliminationPass(vllm_config)
-    func_pass = FixFunctionalizationPass(vllm_config)
-
-    backend = TestBackend(all_reduce_fusion_pass, noop_pass, func_pass)
-
-    token_num = batch_size * seq_len
-    model = test_model_cls(hidden_size, token_num)
-
-    hidden_states = torch.randn((token_num, hidden_size), requires_grad=False)
-    residual = torch.randn((token_num, hidden_size), requires_grad=False)
-
-    compiled_model = torch.compile(model, backend=backend)
-    compiled_model(hidden_states, residual)
-
-    backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False)
-    backend.check_after_ops(model.ops_in_model_after())
-    del all_reduce_fusion_pass
+    with set_current_vllm_config(vllm_config):
+        all_reduce_fusion_pass = AllReduceFusionPass(vllm_config)
+        noop_pass = NoOpEliminationPass(vllm_config)
+        func_pass = FixFunctionalizationPass(vllm_config)
+
+        backend = TestBackend(all_reduce_fusion_pass, noop_pass, func_pass)
+
+        token_num = batch_size * seq_len
+        model = test_model_cls(hidden_size, token_num)
+
+        hidden_states = torch.randn((token_num, hidden_size),
+                                    requires_grad=False)
+        residual = torch.randn((token_num, hidden_size), requires_grad=False)
+
+        compiled_model = torch.compile(model, backend=backend)
+        compiled_model(hidden_states, residual)
+
+        backend.check_before_ops(model.ops_in_model_before(),
+                                 fully_replaced=False)
+        backend.check_after_ops(model.ops_in_model_after())
+        print(backend.graph_pre_pass)
+        print(backend.graph_post_pass)
+        for node in find_op_nodes(
+                torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default,
+                backend.graph_post_pass):
+            print(f"{node.args=}")
+            print(f"{node.kwargs=}")
+
+        del all_reduce_fusion_pass
```
Contributor review comment:
This assertion is commented out. Assertions in test files should not be commented out, since doing so can hide regressions or bugs. If the check is failing, it should be investigated and fixed. Please either re-enable the check or provide a justification for its removal.
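One conventional way to keep such a check visible without blocking the suite (a generic pytest sketch, not code from this PR; the fixture names are hypothetical) is to mark the affected case as an expected failure with a reason, rather than commenting the assertion out:

```python
import pytest


@pytest.mark.xfail(reason="pre-pass op check temporarily disabled while "
                          "pattern matching is reworked; see PR discussion",
                   strict=False)
def test_ops_present_before_fusion(compiled_backend, fused_model):
    # Hypothetical fixtures standing in for the TestBackend/TestModel pair
    # used above; the point is that the assertion still runs and is reported.
    compiled_backend.check_before_ops(fused_model.ops_in_model_before())
```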