33 changes: 19 additions & 14 deletions examples/save_sharded_state_310.py
@@ -24,12 +24,15 @@

Example usage:

python save_sharded_state.py \
python save_sharded_state_310.py \
--model /path/to/load \
--tensor-parallel-size 8 \
--output /path/to/save \
--enable-compress \
--compress-process-num 8
--compress-process-num 8 \
--enforce-eager \
--dtype float16 \
--quantization ascend

Then, the model can be loaded with

@@ -140,29 +143,30 @@ def get_quant_description(json_file: str) -> dict:
return quant_desc


def update_quant_description(json_file: str) -> None:
def update_quant_description(ori_json_file: str, target_json_file: str) -> None:
"""
Update quantization types in the JSON configuration file based on the update mapping.

Args:
json_file: Path to the JSON configuration file
ori_json_file: Path to the original JSON configuration file to read
target_json_file: Path where the updated JSON configuration file is saved

Raises:
RuntimeError: If the file cannot be read, JSON parsing fails, or required keys are missing
"""
config_path = Path(json_file)
config_path = Path(ori_json_file)
try:
with config_path.open("r", encoding="utf-8") as file:
json_data = json.load(file)
except (FileNotFoundError, json.JSONDecodeError) as e:
raise RuntimeError(f"Failed to read configuration file {json_file}: {e}")
raise RuntimeError(f"Failed to read configuration file {ori_json_file}: {e}")

original_quant_type = json_data.get("model_quant_type")
if not original_quant_type or original_quant_type not in QUANTIZATION_UPDATE_MAP:
raise RuntimeError(
f"Cannot update quantization type. "
f"Original type '{original_quant_type}' not found or not supported for update in {json_file}."
f"Original type '{original_quant_type}' not found or not supported for update in {ori_json_file}."
)
updated_quant_type = QUANTIZATION_UPDATE_MAP[original_quant_type]

@@ -175,12 +179,12 @@ def update_quant_description(json_file: str) -> None:
updated_config[key] = value

try:
new_file_path = config_path.parent / "quant_model_description.json"
new_file_path = Path(target_json_file)
with new_file_path.open("w", encoding="utf-8") as file:
json.dump(updated_config, file, indent=2, ensure_ascii=False)
os.remove(json_file)
os.remove(ori_json_file)
except OSError as e:
raise RuntimeError(f"Failed to write updated configuration to {json_file}: {e}")
raise RuntimeError(f"Failed to write updated configuration to {target_json_file}: {e}")


def weight_compress_worker(file_path: str, quant_desc: dict, process_num: int) -> bool:
@@ -214,9 +218,6 @@ def weight_compress_worker(file_path: str, quant_desc: dict, process_num: int) -> bool:
compressor.run()
if p.exists():
os.remove(p)
ori_quant_desc_file = p.parent / "quant_model_description.json"
if ori_quant_desc_file.exists():
os.rename(str(ori_quant_desc_file), str(ori_quant_desc_file.parent / "ori_quant_model_description.json"))
compressor.export_safetensors(str(p.parent), safetensors_name=p.name)
return True
except Exception as e:
@@ -248,6 +249,10 @@ def main(args):
# 4. Compression Logic
parameters_map_fpath = output_dir / "parameters_type_map.json"
if args.enable_compress:
quant_desc_file = output_dir / "quant_model_description.json"
backup_quant_desc_file = output_dir / "ori_quant_model_description.json"
if quant_desc_file.exists():
os.rename(str(quant_desc_file), str(backup_quant_desc_file))
quant_desc = get_quant_description(str(parameters_map_fpath))
quant_type = quant_desc["model_quant_type"]
if quant_type in SUPPORTED_COMPRESS_QUANT_TYPE:
@@ -269,7 +274,7 @@ def main(args):
for p in tasks:
p.join()

update_quant_description(os.path.join(args.output, "ori_quant_model_description.json"))
update_quant_description(str(backup_quant_desc_file), str(quant_desc_file))
print("Compression completed successfully.")
else:
print(f"Skipping compression: Unsupported type {quant_type}")
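For reference, a minimal sketch of how the reworked two-argument update_quant_description is meant to be called once compression finishes. The output directory below is hypothetical; the real script derives both paths from --output, exactly as main() does above.

```python
from pathlib import Path

output_dir = Path("/path/to/save")  # hypothetical --output directory
backup = output_dir / "ori_quant_model_description.json"  # created by the rename in main()
target = output_dir / "quant_model_description.json"      # rewritten description

# Reads the backed-up description, rewrites the quantization types according to
# QUANTIZATION_UPDATE_MAP, saves the result under the original file name, and
# removes the backup file.
update_quant_description(str(backup), str(target))
```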
122 changes: 122 additions & 0 deletions tests/ut/_310p/quantization/test_w8a8sc_310.py
@@ -0,0 +1,122 @@
#
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from unittest.mock import MagicMock, patch

import pytest
import torch

from tests.ut.base import TestBase
from vllm_ascend._310p.quantization.methods.w8a8sc import AscendW8A8SCLinearMethod310


class TestAscendW8A8SCLinearMethod310(TestBase):

def setUp(self):
self.method = AscendW8A8SCLinearMethod310()

def test_get_weight_310(self):
weight = self.method.get_weight(10, 20)
self.assertEqual(weight["weight"].dtype, torch.int8)
self.assertEqual(weight["weight"].shape, (10 * 20, ))
self.assertEqual(weight["index"].dtype, torch.int8)
index_len = math.ceil(10 / 256) * math.ceil(20 / 128) * 8
self.assertEqual(weight["index"].shape, (index_len, ))
self.assertEqual(weight["info"].dtype, torch.int64)
self.assertEqual(weight["info"].shape, (5, ))

def test_get_pertensor_param_310(self):
params = self.method.get_pertensor_param(torch.float16)
self.assertEqual(params["input_scale"].dtype, torch.float16)
self.assertEqual(params["input_offset"].dtype, torch.int8)
self.assertEqual(params["input_scale"].shape, (1, ))
self.assertEqual(params["input_offset"].shape, (1, ))

def test_get_perchannel_param_310(self):
params = self.method.get_perchannel_param(10, torch.float16)

self.assertEqual(params["quant_bias"].dtype, torch.int32)
self.assertEqual(params["deq_scale"].dtype, torch.int64)
self.assertEqual(params["quant_bias"].shape, (10, ))
self.assertEqual(params["deq_scale"].shape, (10, ))

@pytest.mark.skip(
"Skip as npu_matmul_compress_dequant will be supported in PTA 26.0.0.")
@patch("torch.ops.vllm.quantize")
@patch("torch_npu.npu_matmul_compress_dequant")
def test_apply_with_x_not_int8_310(self, mock_matmul_compress_dequant,
mock_quantize):
layer = MagicMock()
layer.aclnn_input_scale = torch.randn(256)
layer.aclnn_input_scale_reciprocal = 1.0 / layer.aclnn_input_scale
layer.aclnn_input_offset = torch.randint(-128,
127, (256, ),
dtype=torch.int8)
layer.weight = torch.randint(-128,
127, (256 * 128, ),
dtype=torch.int8)
layer.index = torch.randint(-128, 127, (8, ), dtype=torch.int8)
layer.deq_scale = torch.randn(128)
layer.quant_bias = torch.randint(-128, 127, (256, ))
layer.params_dtype = torch.float16

x = torch.randn(32, 128)
expect_x_output = torch.randint(-128, 127, x.shape, dtype=torch.int8)
mock_quantize.return_value = expect_x_output

expected_y_output = torch.randn(32, 256)
mock_matmul_compress_dequant.return_value = expected_y_output

output = self.method.apply(layer, x, tp_rank=0)

mock_quantize.assert_called_with(x, layer.aclnn_input_scale,
layer.aclnn_input_scale_reciprocal,
layer.aclnn_input_offset)
mock_matmul_compress_dequant.assert_called_with(
expect_x_output, layer.weight, layer.index, layer.quant_bias,
layer.deq_scale)
self.assertTrue(torch.equal(output, expected_y_output))

@pytest.mark.skip(
"Skip as npu_matmul_compress_dequant will be supported in PTA 26.0.0.")
@patch("torch.ops.vllm.quantize")
@patch("torch_npu.npu_matmul_compress_dequant")
def test_apply_with_x_is_int8_310(self, mock_matmul_compress_dequant,
mock_quantize):
layer = MagicMock()
layer.aclnn_input_scale = torch.randn(256)
layer.aclnn_input_offset = torch.randint(-128,
127, (256, ),
dtype=torch.int8)
layer.weight = torch.randint(-128,
127, (256 * 128, ),
dtype=torch.int8)
layer.index = torch.randint(-128, 127, (8, ), dtype=torch.int8)
layer.deq_scale = torch.randn(128)
layer.quant_bias = torch.randint(-128, 127, (256, ))
layer.params_dtype = torch.float16

x = torch.randint(-128, 127, (32, 128), dtype=torch.int8)

expected_y_output = torch.randn(32, 256)
mock_matmul_compress_dequant.return_value = expected_y_output

output = self.method.apply(layer, x, tp_rank=0)

mock_quantize.assert_not_called()
mock_matmul_compress_dequant.assert_called_with(
x, layer.weight, layer.index, layer.quant_bias, layer.deq_scale)
self.assertTrue(torch.equal(output, expected_y_output))
1 change: 1 addition & 0 deletions vllm_ascend/_310p/quantization/methods/__init__.py
@@ -19,4 +19,5 @@
w8a8_dynamic, # noqa: F401
w8a8_static, # noqa: F401
w8a8s, # noqa: F401
w8a8sc, # noqa: F401
)
116 changes: 116 additions & 0 deletions vllm_ascend/_310p/quantization/methods/w8a8sc.py
@@ -0,0 +1,116 @@
#
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#

import math
from typing import Any

import torch
import torch_npu
from vllm.distributed import get_tensor_model_parallel_rank

from vllm_ascend.ops.linear import AscendRowParallelLinear
from vllm_ascend.quantization.methods.base import AscendLinearScheme

from .registry import register_scheme


@register_scheme("W8A8SC", "linear")
class AscendW8A8SCLinearMethod310(AscendLinearScheme):
"""310P-only W8A8SC static linear scheme.

Notes:
- This scheme is discovered via the 310P local registry.
"""

def get_weight(
self,
input_size: int,
output_size: int,
params_dtype: torch.dtype = torch.float16,
) -> dict[str, Any]:
"""
Get the weight tensors for the W8A8SC quantization scheme.

Args:
input_size: Size of the input dimension (k)
output_size: Size of the output dimension (n)
params_dtype: Data type for parameters, default is torch.float16

Returns:
A dictionary containing:
- "weight": The compressed weight tensor with shape [c], where c is greater than 0
and not larger than k * n
- "index": Compression index generated simultaneously with compressed weights,
with shape [x], where x = k_index * n_index * 8, k_index = ceil(k1 / tilingK),
n_index = ceil(n1 / tilingN), k1 = k / 32, n1 = n / 16
- "info": Compression information with length 5, containing compression block
information tilingN, tilingK, original shape of the pre-compression x2 matrix,
and identifier for the compression block traversal direction
"""
self.input_size = input_size
index_len = math.ceil(input_size / 256) * math.ceil(output_size / 128) * 8
return {
"weight": torch.empty(input_size * output_size, dtype=torch.int8),
"index": torch.empty(index_len, dtype=torch.int8),
"info": torch.empty(5, dtype=torch.int64),
}

def get_pertensor_param(self, params_dtype: torch.dtype) -> dict[str, Any]:
return {
"input_scale": torch.empty(1, dtype=params_dtype),
"input_offset": torch.empty(1, dtype=torch.int8),
}

def get_perchannel_param(self, output_size: int, params_dtype: torch.dtype) -> dict[str, Any]:
return {
"quant_bias": torch.empty(output_size, dtype=torch.int32),
"deq_scale": torch.empty(output_size, dtype=torch.int64),
}

def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: torch.Tensor | None = None,
tp_rank: int | None = 0,
) -> torch.Tensor:
if x.dtype != torch.int8:
x = torch.ops.vllm.quantize(
x,
layer.aclnn_input_scale,
layer.aclnn_input_scale_reciprocal,
layer.aclnn_input_offset,
)

return torch_npu.npu_matmul_compress_dequant(
x,
layer.weight,
layer.index,
layer.quant_bias,
layer.deq_scale,
)

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
layer.aclnn_input_scale = layer.input_scale.data.repeat(self.input_size)
layer.aclnn_input_scale_reciprocal = 1.0 / layer.aclnn_input_scale.data
layer.aclnn_input_offset = layer.input_offset.data.repeat(self.input_size).to(layer.aclnn_input_scale.dtype)
layer.deq_scale.data = layer.deq_scale.data.unsqueeze(0).to(torch.uint64)
layer.quant_bias.data = layer.quant_bias.data.unsqueeze(0)
# Only apply bias on row_parallel_linear when tp_rank is 0.
# torch_npu.npu_matmul_compress_dequant's quant_bias cannot be None.
if isinstance(layer, AscendRowParallelLinear) and get_tensor_model_parallel_rank() != 0:
layer.quant_bias.data = torch.zeros_like(layer.quant_bias)