-
-
Notifications
You must be signed in to change notification settings - Fork 15.6k
[torch.compile] Add torch inductor pass for fusing silu_and_mul with subsequent scaled_fp8_quant operations #10867
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
48 commits
Select commit
Hold shift + click to select a range
2e0031a
init
SageMoore 8a957c7
remove backend format changes
SageMoore 2913716
format
SageMoore 11c6fae
move activation_quant_kernels to the quantization dir
SageMoore 2dfecb5
added replacement unit test
SageMoore 702fa46
added kernel unit test
SageMoore 583ff4c
misc cleanup
SageMoore e5680f7
move activation quant fusion to its own pass
SageMoore 4b775c4
update test
SageMoore d5ff865
format
SageMoore c970dec
format
SageMoore 596c445
format
SageMoore 7ab3e18
format
SageMoore d347431
format
SageMoore 553d99c
format
SageMoore 774559d
format
SageMoore e2fda7f
format
SageMoore 6915fa2
minor comment fix
SageMoore 6d4b8d0
minor updates
SageMoore 6b631b0
fix fix-functionalization
SageMoore 5b78d80
add opcheck test for fused op
SageMoore 391eea5
fix fix_functionalization tests
SageMoore 546b411
Merge branch 'main' of https://github.com/neuralmagic/vllm into sage/…
SageMoore 0d79c17
fix fix_functionalization again
SageMoore 1041529
Merge branch 'main' of https://github.com/neuralmagic/vllm into sage/…
SageMoore 3198f64
format
SageMoore 58111a9
fixup includes
SageMoore 9a18085
refactor math.hpp
SageMoore 5ae5fe0
Merge branch 'main' of https://github.com/neuralmagic/vllm into sage/…
SageMoore e051b24
fix amd build
SageMoore bfdac35
Merge branch 'main' of https://github.com/neuralmagic/vllm into sage/…
SageMoore 8514b0e
review comments and format
SageMoore ec1290a
fix amd build
SageMoore 008b725
review comments and format
SageMoore 4a0ac7e
minor test fix
SageMoore 554012e
Merge branch 'main' of https://github.com/neuralmagic/vllm into sage/…
SageMoore 4d313f6
Merge branch 'main' of https://github.com/neuralmagic/vllm into sage/…
SageMoore 635c798
format
SageMoore 584e437
Merge branch 'main' of https://github.com/neuralmagic/vllm into sage/…
SageMoore c5ae5d7
add long strings back
SageMoore cef3530
remove whitespace
SageMoore aa5a394
misc unit test fixes
SageMoore 2b9d7b4
Merge branch 'main' of https://github.com/neuralmagic/vllm into sage/…
SageMoore 5901fec
Merge branch 'main' of https://github.com/neuralmagic/vllm into sage/…
SageMoore 5926992
remove ActivationQuantFusionPass singleton
SageMoore 241b056
Merge branch 'main' of https://github.com/neuralmagic/vllm into sage/…
SageMoore 1077775
Merge branch 'main' of https://github.com/neuralmagic/vllm into sage/…
SageMoore 24640e1
add header
SageMoore File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,120 @@ | ||
| #include <ATen/cuda/CUDAContext.h> | ||
| #include <torch/all.h> | ||
| #include <c10/cuda/CUDAGuard.h> | ||
|
|
||
| #include <cmath> | ||
| #include "core/math.hpp" | ||
| #include "cuda_compat.h" | ||
| #include "dispatch_utils.h" | ||
|
|
||
| #include "quantization/fp8/common.cuh" | ||
|
|
||
| namespace vllm { | ||
|
|
||
| template <typename T> | ||
| __device__ __forceinline__ T silu_kernel(const T& x) { | ||
| // x * sigmoid(x) | ||
| return (T)(((float)x) / (1.0f + expf((float)-x))); | ||
| } | ||
|
|
||
// Activation and gating kernel template.
//
// Computes out[t, i] = fp8_quant(ACT_FN(x[t, i]) * y[t, i], *scale), where
// the input packs x and y back-to-back along the last dimension
// ([..., 2, d]).  One token is handled per blockIdx.x; the hidden dimension
// of each token is split across gridDim.y blocks, and each block processes
// its chunk with 128-bit vectorized loads plus a scalar tail.
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&),
          typename fp8_type>
__global__ void act_and_mul_quant_kernel(
    fp8_type* __restrict__ out,          // [..., d]
    const scalar_t* __restrict__ input,  // [..., 2, d]
    const float* scale, const int d) {
  // gridDim.y encodes how many blocks cooperate on a single token.
  const int32_t blocks_per_token = gridDim.y;

  // Number of scalar_t elements that fit in one 128-bit (16-byte) load,
  // e.g. 8 for fp16/bf16.
  const int32_t elems_per_128bit_load = (128 / 8) / sizeof(scalar_t);

  // We don't expect the hidden dimension to exceed 32 bits so int32 should
  // be safe here.
  // div_ceil / round_to_next_multiple_of are project helpers (core/math.hpp);
  // the per-block chunk is rounded up so every block starts on a 128-bit
  // vector boundary.
  const int32_t tgt_elems_per_block = div_ceil(d, blocks_per_token);
  const int32_t elems_per_block =
      round_to_next_multiple_of(tgt_elems_per_block, elems_per_128bit_load);
  const int32_t block_start = blockIdx.y * elems_per_block;
  int32_t block_end = block_start + elems_per_block;
  // Clamp the last chunk so it never reads/writes past the hidden dim.
  block_end = block_end > d ? d : block_end;

  // token_idx is 64 bit to prevent 32 bit overflow when the number of tokens
  // is very large
  const int64_t token_idx = blockIdx.x;
  // x is the first half of the packed last dimension, y the second half.
  const scalar_t* __restrict__ x_ptr = input + token_idx * 2 * d;
  const scalar_t* __restrict__ y_ptr = input + token_idx * 2 * d + d;
  fp8_type* __restrict__ out_ptr = out + token_idx * d;

  // 128-bit vectorized code
  // vec_loop_end is block_end rounded down to a multiple of the vector
  // width; elements past it are handled by the scalar cleanup loop below.
  const int32_t vec_loop_end =
      round_to_previous_multiple_of(elems_per_128bit_load, block_end);
  const int32_t vec_end_idx = vec_loop_end / elems_per_128bit_load;
  const int32_t vec_start_idx = block_start / elems_per_128bit_load;

  // Reinterpret the element pointers as 128-bit (int4) loads for inputs and
  // a narrower store for the fp8 output vector.
  const int4* __restrict__ x_128bit_ptr = reinterpret_cast<const int4*>(x_ptr);
  const int4* __restrict__ y_128bit_ptr = reinterpret_cast<const int4*>(y_ptr);
  int2* __restrict__ out_128bit_ptr = reinterpret_cast<int2*>(out_ptr);

  // scaled_fp8_conversion multiplies by the inverse scale, so invert once
  // per thread instead of dividing per element.
  float inverted_scale = 1 / *scale;
#pragma unroll
  for (int32_t vec_idx = vec_start_idx + threadIdx.x; vec_idx < vec_end_idx;
       vec_idx += blockDim.x) {
    const int4 x_128bit = VLLM_LDG(&x_128bit_ptr[vec_idx]);
    const int4 y_128bit = VLLM_LDG(&y_128bit_ptr[vec_idx]);
    using scalar_128bit_vec_t = std::array<scalar_t, elems_per_128bit_load>;
    // NOTE(review): named "64bit" on the assumption that fp8_type is 1 byte
    // and scalar_t is 2 bytes (8 fp8 elems = 64 bits) — confirm for other
    // instantiations.
    using scalar_64bit_vec_t = std::array<fp8_type, elems_per_128bit_load>;

    scalar_64bit_vec_t out_vec;
    const auto x_vec = reinterpret_cast<scalar_128bit_vec_t const&>(x_128bit);
    const auto y_vec = reinterpret_cast<scalar_128bit_vec_t const&>(y_128bit);

#pragma unroll
    for (int i = 0; i < elems_per_128bit_load; i++) {
      // Activation on the gate half, multiply by the up half, then quantize.
      out_vec[i] = scaled_fp8_conversion<true, fp8_type>(
          ACT_FN(x_vec[i]) * y_vec[i], inverted_scale);
    }

    out_128bit_ptr[vec_idx] = reinterpret_cast<const int2&>(out_vec);
  }

  // Scalar cleanup code
  // Handles the (block_end - vec_loop_end) trailing elements that don't fill
  // a whole 128-bit vector.
  if (block_end > vec_loop_end) {
    for (int64_t idx = vec_loop_end + threadIdx.x; idx < block_end;
         idx += blockDim.x) {
      const scalar_t x = VLLM_LDG(&x_ptr[idx]);
      const scalar_t y = VLLM_LDG(&y_ptr[idx]);
      out_ptr[idx] =
          scaled_fp8_conversion<true, fp8_type>(ACT_FN(x) * y, inverted_scale);
    }
  }
}
| } // namespace vllm | ||
|
|
||
// Launch activation, gating, and quantize kernel.
// Expects `out`, `input`, and `scale` torch::Tensors in scope. grid.y gives
// each token more blocks when there are few tokens (4 blocks/token for
// <=16 tokens, 2 for 17-32, 1 otherwise) to keep the GPU occupied.
// NOTE: the dispatch labels below previously said "act_and_mul_kernel" and
// "fused_add_rms_norm_kernel_fp8_type" (copy-paste from another kernel);
// they are diagnostic strings shown on unsupported-dtype errors and must
// name this kernel.
#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL)                              \
  int d = input.size(-1) / 2;                                              \
  int64_t num_tokens = input.numel() / input.size(-1);                     \
  dim3 grid(num_tokens, num_tokens > 16 ? num_tokens > 32 ? 1 : 2 : 4);   \
  dim3 block(std::min(d, 512));                                            \
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));        \
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();            \
  VLLM_DISPATCH_FLOATING_TYPES(                                            \
      input.scalar_type(), "act_and_mul_quant_kernel", [&] {               \
        VLLM_DISPATCH_FP8_TYPES(                                           \
            out.scalar_type(), "act_and_mul_quant_kernel_fp8_type", [&] {  \
              vllm::act_and_mul_quant_kernel<scalar_t, KERNEL<scalar_t>,   \
                                             fp8_t>                        \
                  <<<grid, block, 0, stream>>>(out.data_ptr<fp8_t>(),      \
                                               input.data_ptr<scalar_t>(), \
                                               scale.data_ptr<float>(),    \
                                               d);                         \
            });                                                            \
      });
|
|
||
| void silu_and_mul_quant(torch::Tensor& out, // [..., d] | ||
| torch::Tensor& input, // [..., 2 * d] | ||
| torch::Tensor& scale) { | ||
| TORCH_CHECK(out.dtype() == torch::kFloat8_e4m3fn); | ||
| TORCH_CHECK(input.dtype() == torch::kFloat16 || | ||
| input.dtype() == torch::kBFloat16); | ||
| TORCH_CHECK(input.size(-1) % 2 == 0); | ||
| LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel); | ||
| } | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,74 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| import pytest | ||
| import torch | ||
|
|
||
| import vllm.envs as envs | ||
| from vllm._custom_ops import scaled_fp8_quant | ||
| from vllm.compilation.activation_quant_fusion import ActivationQuantFusionPass | ||
| from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe | ||
| from vllm.config import CompilationConfig, VllmConfig | ||
| from vllm.model_executor.layers.activation import SiluAndMul | ||
|
|
||
| from .backend import TestBackend | ||
|
|
||
|
|
||
class TestModel(torch.nn.Module):
    """Minimal module exercising the SiluAndMul -> static fp8 quant pattern
    that ActivationQuantFusionPass is expected to fuse."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.silu_and_mul = SiluAndMul()
        # Static quantization scale applied by scaled_fp8_quant.
        self.scale = torch.rand(1, dtype=torch.float32)

    def forward(self, x):
        # silu_and_mul immediately followed by a static scaled fp8 quant --
        # exactly the producer/consumer pair the fusion pass matches.
        return scaled_fp8_quant(self.silu_and_mul(x), self.scale)
|
|
||
|
|
||
@pytest.mark.parametrize("num_tokens", [256])
@pytest.mark.parametrize("hidden_size", [64])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda",
                    reason="Only test on CUDA")
def test_fusion_silu_and_mul_quant(num_tokens, hidden_size):
    """ActivationQuantFusionPass replaces silu_and_mul + static fp8 quant
    with the fused silu_and_mul_quant op, without changing numerics."""
    torch.set_default_device("cuda")
    torch.set_default_dtype(torch.float16)

    # Reshape pass is needed for the fusion pass to work
    vllm_config = VllmConfig()
    vllm_config.compilation_config = CompilationConfig(
        pass_config=CompilationConfig.PassConfig(enable_fusion=True,
                                                 enable_reshape=True))
    fusion_pass = ActivationQuantFusionPass(vllm_config)
    backend = TestBackend(fusion_pass)

    model = TestModel()

    # First dimension dynamic
    inp = torch.rand(num_tokens, hidden_size)
    torch._dynamo.mark_dynamic(inp, 0)

    eager_result = model(inp)
    compiled_model = torch.compile(model, backend=backend)
    compiled_result = compiled_model(inp)

    # Check that it gives the same answer
    torch.testing.assert_close(eager_result[0].to(dtype=torch.float16),
                               compiled_result[0].to(dtype=torch.float16),
                               atol=1e-3,
                               rtol=1e-3)

    # Check substitution worked
    fused_op = torch.ops._C.silu_and_mul_quant.default
    quant_op = torch.ops._C.static_scaled_fp8_quant.default

    # In pre-nodes, fp8 quant should be present and fused kernels should not
    pre_nodes = backend.graph_pre_pass.nodes
    assert find_auto_fn_maybe(pre_nodes, fused_op) is None
    find_auto_fn(pre_nodes, quant_op)

    # In post-nodes, fused kernels should be present and fp8 quant should not
    post_nodes = backend.graph_post_pass.nodes
    find_auto_fn(post_nodes, fused_op)
    assert find_auto_fn_maybe(post_nodes, quant_op) is None
SageMoore marked this conversation as resolved.
Show resolved
Hide resolved
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,69 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| import pytest | ||
| import torch | ||
|
|
||
| import vllm._custom_ops as ops | ||
| from tests.kernels.utils import opcheck | ||
| from vllm.model_executor.layers.activation import SiluAndMul | ||
|
|
||
| DTYPES = [torch.bfloat16, torch.float16] | ||
| QUANT_DTYPES = [torch.float8_e4m3fn] | ||
| NUM_TOKENS = [1, 17, 86, 1234, 3045] # Arbitrary values for testing | ||
| HIDDEN_SIZES = [16, 48, 128, 1562, 4096] # Arbitrary values for testing | ||
| SEEDS = [0] | ||
| CUDA_DEVICES = [ | ||
| f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) | ||
| ] | ||
|
|
||
|
|
||
def ref_impl(silu_and_mul: SiluAndMul, x: torch.Tensor,
             scale: torch.Tensor) -> torch.Tensor:
    """Unfused reference: native SiluAndMul followed by static scaled fp8
    quantization. Returns only the quantized tensor."""
    silu_and_mul_out = silu_and_mul.forward_native(x)
    # scaled_fp8_quant returns (quantized, scale); the scale is the static
    # one passed in, so discard it instead of binding an unused local.
    out, _ = ops.scaled_fp8_quant(silu_and_mul_out, scale)
    return out
|
|
||
|
|
||
def ops_impl(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Run the fused silu_and_mul_quant custom op.

    The op writes into a pre-allocated fp8 output of shape
    (num_tokens, hidden_size // 2), since the input packs gate and up
    projections along the last dimension.
    """
    out_shape = (x.shape[0], x.shape[1] // 2)
    # Fix: was `torch.torch.float8_e4m3fn`, which only works through the
    # accidental `torch.torch` self-reference.
    out = torch.empty(out_shape,
                      dtype=torch.float8_e4m3fn,
                      device=x.device)
    torch.ops._C.silu_and_mul_quant(out, x, scale)
    return out
|
|
||
|
|
||
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("quant_dtype", QUANT_DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_silu_and_mul(
    num_tokens: int,
    hidden_size: int,
    dtype: torch.dtype,
    quant_dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    """Compare the fused silu_and_mul_quant custom op (ops_impl) against the
    unfused reference (ref_impl) for exact dtype, shape, and value match."""
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    torch.set_default_device(device)

    layer = SiluAndMul()

    # Make inputs
    # NOTE(review): randn can draw a negative scale; both implementations
    # receive the same value so the comparison stays consistent, but confirm
    # a negative quantization scale is intentional (rand() would keep it
    # positive).
    scale = (torch.randn((1), device=device, dtype=torch.float32))
    x = torch.randn(num_tokens, hidden_size, dtype=dtype)

    ref_out = ref_impl(layer, x, scale)
    ops_out = ops_impl(x, scale)

    assert ref_out.dtype == quant_dtype
    assert ops_out.dtype == quant_dtype
    assert ref_out.shape == ops_out.shape
    # Compare in fp32 after dequant-free upcast of the fp8 results.
    assert torch.allclose(ref_out.to(dtype=torch.float32),
                          ops_out.to(dtype=torch.float32))
    # Schema/registration check for the custom op; ops_out doubles as the
    # (already-written) output buffer argument.
    opcheck(torch.ops._C.silu_and_mul_quant, (ops_out, x, scale))
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.