-
Notifications
You must be signed in to change notification settings - Fork 1.2k
[Feat] [310p] Support w8a8sc quantization method #7075
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
7 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,122 @@ | ||
| # | ||
| # Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| import math | ||
| from unittest.mock import MagicMock, patch | ||
|
|
||
| import pytest | ||
| import torch | ||
|
|
||
| from tests.ut.base import TestBase | ||
| from vllm_ascend._310p.quantization.methods.w8a8sc import AscendW8A8SCLinearMethod310 | ||
|
|
||
|
|
||
| class TestAscendW8A8SCLinearMethod310(TestBase): | ||
|
|
||
| def setUp(self): | ||
| self.method = AscendW8A8SCLinearMethod310() | ||
|
|
||
| def test_get_weight_310(self): | ||
| weight = self.method.get_weight(10, 20) | ||
| self.assertEqual(weight["weight"].dtype, torch.int8) | ||
| self.assertEqual(weight["weight"].shape, (10 * 20, )) | ||
| self.assertEqual(weight["index"].dtype, torch.int8) | ||
| index_len = math.ceil(10 / 256) * math.ceil(20 / 128) * 8 | ||
| self.assertEqual(weight["index"].shape, (index_len, )) | ||
| self.assertEqual(weight["info"].dtype, torch.int64) | ||
| self.assertEqual(weight["info"].shape, (5, )) | ||
|
|
||
| def test_get_pertensor_param_310(self): | ||
| params = self.method.get_pertensor_param(torch.float16) | ||
| self.assertEqual(params["input_scale"].dtype, torch.float16) | ||
| self.assertEqual(params["input_offset"].dtype, torch.int8) | ||
| self.assertEqual(params["input_scale"].shape, (1, )) | ||
| self.assertEqual(params["input_offset"].shape, (1, )) | ||
|
|
||
| def test_get_perchannel_param_310(self): | ||
| params = self.method.get_perchannel_param(10, torch.float16) | ||
|
|
||
| self.assertEqual(params["quant_bias"].dtype, torch.int32) | ||
| self.assertEqual(params["deq_scale"].dtype, torch.int64) | ||
| self.assertEqual(params["quant_bias"].shape, (10, )) | ||
| self.assertEqual(params["deq_scale"].shape, (10, )) | ||
|
|
||
| @pytest.mark.skip( | ||
| "Skip as npu_matmul_compress_dequant will be supported in PTA 26.0.0.") | ||
| @patch("torch.ops.vllm.quantize") | ||
| @patch("torch_npu.npu_matmul_compress_dequant") | ||
| def test_apply_with_x_not_int8_310(self, mock_matmul_compress_dequant, | ||
| mock_quantize): | ||
| layer = MagicMock() | ||
| layer.aclnn_input_scale = torch.randn(256) | ||
| layer.aclnn_input_scale_reciprocal = 1.0 / layer.aclnn_input_scale | ||
| layer.aclnn_input_offset = torch.randint(-128, | ||
| 127, (256, ), | ||
| dtype=torch.int8) | ||
| layer.weight = torch.randint(-128, | ||
| 127, (256 * 128, ), | ||
| dtype=torch.int8) | ||
| layer.index = torch.randint(-128, 127, (8, ), dtype=torch.int8) | ||
| layer.deq_scale = torch.randn(128) | ||
| layer.quant_bias = torch.randint(-128, 127, (256, )) | ||
| layer.params_dtype = torch.float16 | ||
|
|
||
| x = torch.randn(32, 128) | ||
| expect_x_output = torch.randint(-128, 127, x.shape, dtype=torch.int8) | ||
| mock_quantize.return_value = expect_x_output | ||
|
|
||
| expected_y_output = torch.randn(32, 256) | ||
| mock_matmul_compress_dequant.return_value = expected_y_output | ||
|
|
||
| output = self.method.apply(layer, x, tp_rank=0) | ||
|
|
||
| mock_quantize.assert_called_with(x, layer.aclnn_input_scale, | ||
| layer.aclnn_input_scale_reciprocal, | ||
| layer.aclnn_input_offset) | ||
| mock_matmul_compress_dequant.assert_called_with( | ||
| expect_x_output, layer.weight, layer.index, layer.quant_bias, | ||
| layer.deq_scale) | ||
| self.assertTrue(torch.equal(output, expected_y_output)) | ||
|
|
||
| @pytest.mark.skip( | ||
| "Skip as npu_matmul_compress_dequant will be supported in PTA 26.0.0.") | ||
| @patch("torch.ops.vllm.quantize") | ||
| @patch("torch_npu.npu_matmul_compress_dequant") | ||
| def test_apply_with_x_is_int8_310(self, mock_matmul_compress_dequant, | ||
| mock_quantize): | ||
| layer = MagicMock() | ||
| layer.aclnn_input_scale = torch.randn(256) | ||
| layer.aclnn_input_offset = torch.randint(-128, | ||
| 127, (256, ), | ||
| dtype=torch.int8) | ||
| layer.weight = torch.randint(-128, | ||
| 127, (256 * 128, ), | ||
| dtype=torch.int8) | ||
| layer.index = torch.randint(-128, 127, (8, ), dtype=torch.int8) | ||
| layer.deq_scale = torch.randn(128) | ||
| layer.quant_bias = torch.randint(-128, 127, (256, )) | ||
| layer.params_dtype = torch.float16 | ||
|
|
||
| x = torch.randint(-128, 127, (32, 128), dtype=torch.int8) | ||
|
|
||
| expected_y_output = torch.randn(32, 256) | ||
| mock_matmul_compress_dequant.return_value = expected_y_output | ||
|
|
||
| output = self.method.apply(layer, x, tp_rank=0) | ||
|
|
||
| mock_quantize.assert_not_called() | ||
| mock_matmul_compress_dequant.assert_called_with( | ||
| x, layer.weight, layer.index, layer.quant_bias, layer.deq_scale) | ||
| self.assertTrue(torch.equal(output, expected_y_output)) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -19,4 +19,5 @@ | |
| w8a8_dynamic, # noqa: F401 | ||
| w8a8_static, # noqa: F401 | ||
| w8a8s, # noqa: F401 | ||
| w8a8sc, # noqa: F401 | ||
| ) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,116 @@ | ||
| # | ||
| # Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| # This file is a part of the vllm-ascend project. | ||
| # | ||
|
|
||
| import math | ||
| from typing import Any | ||
|
|
||
| import torch | ||
| import torch_npu | ||
| from vllm.distributed import get_tensor_model_parallel_rank | ||
|
|
||
| from vllm_ascend.ops.linear import AscendRowParallelLinear | ||
| from vllm_ascend.quantization.methods.base import AscendLinearScheme | ||
|
|
||
| from .registry import register_scheme | ||
|
|
||
|
|
||
| @register_scheme("W8A8SC", "linear") | ||
| class AscendW8A8SCLinearMethod310(AscendLinearScheme): | ||
| """310P-only W8A8SC static linear scheme. | ||
|
|
||
| Notes: | ||
| - This scheme is discovered via 310P local registry. | ||
| """ | ||
|
|
||
| def get_weight( | ||
| self, | ||
| input_size: int, | ||
| output_size: int, | ||
| params_dtype: torch.dtype = torch.float16, | ||
| ) -> dict[str, Any]: | ||
| """ | ||
| Get the weight tensors for the W8A8SC quantization scheme. | ||
|
|
||
| Args: | ||
| input_size: Size of the input dimension (k) | ||
| output_size: Size of the output dimension (n) | ||
| params_dtype: Data type for parameters, default is torch.float16 | ||
|
|
||
| Returns: | ||
| A dictionary containing: | ||
| - "weight": The compressed weight tensor with shape [c], where c is greater than 0 | ||
| and not larger than k * n | ||
| - "index": Compression index generated simultaneously with compressed weights, | ||
| with shape [x], where x = k_index * n_index * 8, k_index = ceil(k1 / tilingK), | ||
| n_index = ceil(n1 / tilingN), k1 = k / 32, n1 = n / 16 | ||
| - "info": Compression information with length 5, containing compression block | ||
| information tilingN, tilingK, original shape of the pre-compression x2 matrix, | ||
| and identifier for the compression block traversal direction | ||
| """ | ||
| self.input_size = input_size | ||
| index_len = math.ceil(input_size / 256) * math.ceil(output_size / 128) * 8 | ||
| return { | ||
| "weight": torch.empty(input_size * output_size, dtype=torch.int8), | ||
| "index": torch.empty(index_len, dtype=torch.int8), | ||
| "info": torch.empty(5, dtype=torch.int64), | ||
| } | ||
|
|
||
| def get_pertensor_param(self, params_dtype: torch.dtype) -> dict[str, Any]: | ||
| return { | ||
| "input_scale": torch.empty(1, dtype=params_dtype), | ||
| "input_offset": torch.empty(1, dtype=torch.int8), | ||
| } | ||
|
|
||
| def get_perchannel_param(self, output_size: int, params_dtype: torch.dtype) -> dict[str, Any]: | ||
| return { | ||
| "quant_bias": torch.empty(output_size, dtype=torch.int32), | ||
| "deq_scale": torch.empty(output_size, dtype=torch.int64), | ||
| } | ||
|
|
||
| def apply( | ||
| self, | ||
| layer: torch.nn.Module, | ||
| x: torch.Tensor, | ||
| bias: torch.Tensor | None = None, | ||
| tp_rank: int | None = 0, | ||
| ) -> torch.Tensor: | ||
| if x.dtype != torch.int8: | ||
| x = torch.ops.vllm.quantize( | ||
| x, | ||
| layer.aclnn_input_scale, | ||
| layer.aclnn_input_scale_reciprocal, | ||
| layer.aclnn_input_offset, | ||
| ) | ||
|
|
||
| return torch_npu.npu_matmul_compress_dequant( | ||
| x, | ||
| layer.weight, | ||
| layer.index, | ||
| layer.quant_bias, | ||
| layer.deq_scale, | ||
| ) | ||
|
|
||
| def process_weights_after_loading(self, layer: torch.nn.Module) -> None: | ||
| layer.aclnn_input_scale = layer.input_scale.data.repeat(self.input_size) | ||
| layer.aclnn_input_scale_reciprocal = 1.0 / layer.aclnn_input_scale.data | ||
| layer.aclnn_input_offset = layer.input_offset.data.repeat(self.input_size).to(layer.aclnn_input_scale.dtype) | ||
| layer.deq_scale.data = layer.deq_scale.data.unsqueeze(0).to(torch.uint64) | ||
| layer.quant_bias.data = layer.quant_bias.data.unsqueeze(0) | ||
| # Only apply bias on row_parallel_linear when tp_rank is 0. | ||
| # torch_npu.npu_matmul_compress_dequant's quant_bias cannot be None. | ||
| if isinstance(layer, AscendRowParallelLinear) and get_tensor_model_parallel_rank() != 0: | ||
| layer.quant_bias.data = torch.zeros_like(layer.quant_bias) | ||
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.