Skip to content

Commit ec4ea08

Browse files
wangqkit-is-a-robot
wangqk
authored and committed
!14029 add op npu_dynamic_quant_asymmetric、npu_dynamic_quant
Merge pull request !14029 from wangqk/dev_asymmetrical_dynamic_quant_master
1 parent 5e2be70 commit ec4ea08

File tree

5 files changed

+148
-0
lines changed

5 files changed

+148
-0
lines changed

test/allowlist_for_publicAPI.json

+2
Original file line numberDiff line numberDiff line change
@@ -2822,6 +2822,8 @@
28222822
"npu_quant_scatter",
28232823
"npu_scatter_nd_update_",
28242824
"npu_swiglu",
2825+
"npu_dynamic_quant",
2826+
"npu_dynamic_quant_asymmetric",
28252827
"npu_yolo_boxes_encode",
28262828
"npu_yolo_boxes_encode",
28272829
"npu_weight_quant_batchmatmul",

test/onnx/test_wrapper_onnx_ops.py

+42
Original file line numberDiff line numberDiff line change
@@ -1337,6 +1337,48 @@ def export_onnx(onnx_model_name):
13371337
export_onnx(onnx_model_name)
13381338
assert (os.path.isfile(os.path.join(TestOnnxOps.test_onnx_path, onnx_model_name)))
13391339

1340+
@SupportedDevices(['Ascend910B'])
def test_wrapper_npu_dynamic_quant(self):
    """Export npu_dynamic_quant through ONNX and check the model file is produced."""

    class Model(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, input_dummy, smooth_scales_dummy):
            # Dynamic (symmetric) quantization: quantized output plus per-row scale.
            return torch_npu.npu_dynamic_quant(input_dummy, smooth_scales=smooth_scales_dummy)

    onnx_model_name = "model_npu_dynamic_quant.onnx"
    input_dummy = torch.rand(4, 1024, 512).uniform_(-3, 3).npu().to(torch.float16)
    smooth_scales_dummy = torch.rand(512).uniform_(-3, 3).npu().to(torch.float16)
    model = Model().to("npu")
    # Run once eagerly before export, mirroring the other wrapper tests.
    model(input_dummy, smooth_scales_dummy)
    self.onnx_export(model, (input_dummy, smooth_scales_dummy), onnx_model_name,
                     ["input", "smooth_scale_dummy"], ["output", "scale"])
    assert os.path.isfile(os.path.join(TestOnnxOps.test_onnx_path, onnx_model_name))
1360+
1361+
@SupportedDevices(['Ascend910B'])
def test_wrapper_npu_dynamic_quant_asymmetric(self):
    """Export npu_dynamic_quant_asymmetric through ONNX and check the model file is produced."""

    class Model(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, input_dummy, smooth_scales_dummy):
            # Asymmetric dynamic quantization: output plus per-row scale and offset.
            return torch_npu.npu_dynamic_quant_asymmetric(input_dummy, smooth_scales=smooth_scales_dummy)

    onnx_model_name = "model_npu_dynamic_quant_asymmetric.onnx"
    input_dummy = torch.rand(4, 1024, 512).uniform_(-3, 3).npu().to(torch.float16)
    smooth_scales_dummy = torch.rand(512).uniform_(-3, 3).npu().to(torch.float16)
    model = Model().to("npu")
    # Run once eagerly before export, mirroring the other wrapper tests.
    model(input_dummy, smooth_scales_dummy)
    self.onnx_export(model, (input_dummy, smooth_scales_dummy), onnx_model_name,
                     ["input", "smooth_scale_dummy"], ["output", "scale", "offset"])
    assert os.path.isfile(os.path.join(TestOnnxOps.test_onnx_path, onnx_model_name))
1381+
13401382
@SupportedDevices(['Ascend910B'])
13411383
def test_wrapper_npu_weight_quant_batchmatmul(self):
13421384
class Model(torch.nn.Module):

test/test_fake_tensor.py

+42
Original file line numberDiff line numberDiff line change
@@ -1595,6 +1595,48 @@ def test_npu_ffn_meta(self):
15951595
self.assertTrue(x.shape == res.shape)
15961596

15971597

1598+
class TestNpuDynamicQuant(TestCase):
    def test_npu_dynamic_quant(self):
        """Check npu_dynamic_quant fake-tensor metadata (dtype, shape, device)."""
        with FakeTensorMode():
            input_npu = torch.randn((4, 2048, 1024)).npu().to(torch.float16)
            smooth_scales_npu = torch.randn((1024)).npu().to(torch.float16)

            # Templates carrying the expected metadata; values are irrelevant.
            expected_output = torch.randn((4, 2048, 1024)).npu().to(torch.int8)
            expected_scale = torch.randn((4, 2048)).npu().to(torch.float32)

            actual_output, actual_scale = torch_npu.npu_dynamic_quant(
                input_npu, smooth_scales=smooth_scales_npu)

            for actual, expected in ((actual_output, expected_output),
                                     (actual_scale, expected_scale)):
                self.assertEqual(actual.dtype, expected.dtype)
                self.assertEqual(actual.shape, expected.shape)
                self.assertEqual(actual.device, expected.device)
1616+
1617+
class TestDynamicQuantAsymmetric(TestCase):
    def test_npu_dynamic_quant_asymmetric(self):
        """Check npu_dynamic_quant_asymmetric fake-tensor metadata (dtype, shape, device)."""
        with FakeTensorMode():
            input_npu = torch.randn((4, 2048, 1024)).npu().to(torch.float16)
            smooth_scales_npu = torch.randn((1024)).npu().to(torch.float16)

            # Templates carrying the expected metadata; values are irrelevant.
            expected_output = torch.randn((4, 2048, 1024)).npu().to(torch.int8)
            expected_scale = torch.randn((4, 2048)).npu().to(torch.float32)
            expected_offset = torch.randn((4, 2048)).npu().to(torch.float32)

            actual_output, actual_scale, actual_offset = torch_npu.npu_dynamic_quant_asymmetric(
                input_npu, smooth_scales=smooth_scales_npu)

            pairs = ((actual_output, expected_output),
                     (actual_scale, expected_scale),
                     (actual_offset, expected_offset))
            for actual, expected in pairs:
                self.assertEqual(actual.dtype, expected.dtype)
                self.assertEqual(actual.shape, expected.shape)
                self.assertEqual(actual.device, expected.device)
1638+
1639+
15981640
class TestGroupedMatmul(TestCase):
15991641
def test_npu_grouped_matmul_meta_0(self):
16001642
with FakeTensorMode():

torch_npu/meta/_meta_registrations.py

+21
Original file line numberDiff line numberDiff line change
@@ -627,6 +627,27 @@ def npu_quantize_meta(self, scales, zero_points, dtype, axis=1, div_mode=True):
627627
return torch.empty_like(self, dtype=torch.int8)
628628

629629

630+
@impl(m, "npu_dynamic_quant")
def npu_dynamic_quant(input_dummy, *, smooth_scales=None):
    """Meta (shape/dtype inference) implementation for npu_dynamic_quant.

    Args:
        input_dummy: input tensor to be quantized.
        smooth_scales: optional smoothing scales; unused for shape inference.

    Returns:
        tuple: (int8 tensor with the input's shape,
                float32 scale tensor with the input's shape minus the last dim).
    """
    # The scale is produced per row, i.e. one value for every index of all
    # leading dimensions — equivalent to the original element-by-element loop.
    scale_shape = input_dummy.shape[:-1]
    return (torch.empty_like(input_dummy, dtype=torch.int8),
            input_dummy.new_empty(scale_shape, dtype=torch.float32))
638+
639+
640+
@impl(m, "npu_dynamic_quant_asymmetric")
def npu_dynamic_quant_asymmetric(input_dummy, *, smooth_scales=None, group_index=None, dst_type=torch.int8):
    """Meta (shape/dtype inference) implementation for npu_dynamic_quant_asymmetric.

    Args:
        input_dummy: input tensor to be quantized.
        smooth_scales: optional smoothing scales; unused for shape inference.
        group_index: optional grouping index; unused for shape inference.
        dst_type: requested output dtype (default torch.int8).

    Returns:
        tuple: (int8 tensor with the input's shape,
                float32 scale tensor with the input's shape minus the last dim,
                float32 offset tensor with the same shape as the scale).
    """
    # Scale and offset are produced per row — one value for every index of the
    # leading dimensions; equivalent to the original element-by-element loop.
    scale_offset_shape = input_dummy.shape[:-1]
    # NOTE(review): dst_type is accepted but the output dtype is fixed to int8,
    # matching the ONNX wrapper's "only int8 supported" restriction — confirm
    # before widening.
    return (torch.empty_like(input_dummy, dtype=torch.int8),
            input_dummy.new_empty(scale_offset_shape, dtype=torch.float32),
            input_dummy.new_empty(scale_offset_shape, dtype=torch.float32))
649+
650+
630651
@impl(m, "npu_moe_compute_expert_tokens")
631652
def npu_moe_compute_expert_tokens_meta(sorted_experts, num_experts=1):
632653
out = torch.zeros(num_experts, dtype=torch.int32, device='meta')

torch_npu/onnx/wrapper_onnx_ops.py

+41
Original file line numberDiff line numberDiff line change
@@ -730,6 +730,37 @@ def symbolic(g, x1: torch.Tensor, x2: torch.Tensor, hcom: str,
730730
dequant_scale, pertoken_scale, comm_quant_scale_1, comm_quant_scale_2, antiquant_group_size, comm_turn)
731731

732732

733+
class _NPUDynamicQuantOp(torch.autograd.Function):
    """Autograd wrapper that exports npu_dynamic_quant as the ONNX op NPUDynamicQuant."""

    @staticmethod
    def forward(ctx, input_dummy, smooth_scales):
        result = torch.ops.npu.npu_dynamic_quant(input_dummy, smooth_scales=smooth_scales)
        return result

    @staticmethod
    def symbolic(g, input_dummy: Tensor, smooth_scales: Optional[Tensor] = None):
        # A missing smooth_scales is emitted as an empty constant of the
        # input's dtype so the exported node always has two inputs.
        if smooth_scales is None:
            placeholder = torch.tensor([]).to(input_dummy.type().dtype())
            smooth_scales = g.op("Constant", value_t=placeholder)
        return g.op("npu::NPUDynamicQuant", input_dummy, smooth_scales, outputs=2)
744+
745+
746+
class _NPUDynamicQuantV2Op(torch.autograd.Function):
    """Autograd wrapper that exports npu_dynamic_quant_asymmetric as the ONNX op NPUDynamicQuantV2."""

    @staticmethod
    def forward(ctx, input_dummy, smooth_scales, group_index, dst_type):
        result = torch.ops.npu.npu_dynamic_quant_asymmetric(
            input_dummy, smooth_scales=smooth_scales,
            group_index=group_index, dst_type=dst_type)
        return result

    @staticmethod
    def symbolic(g, input_dummy: Tensor, smooth_scales: Optional[Tensor] = None,
                 group_index: Optional[Tensor] = None, dst_type: torch.dtype = torch.int8):
        # Missing optional tensors become empty constants so the exported
        # node always carries three inputs.
        if smooth_scales is None:
            placeholder = torch.tensor([]).to(input_dummy.type().dtype())
            smooth_scales = g.op("Constant", value_t=placeholder)
        if group_index is None:
            group_index = g.op("Constant", value_t=torch.tensor([]).to(torch.int32))
        dst_type_i = 2  # currently only int8 is supported
        return g.op("npu::NPUDynamicQuantV2", input_dummy, smooth_scales,
                    group_index, dst_type_i=dst_type_i, outputs=3)
763+
733764

734765
class _NPUWeightQuantBatchMatmulOP(torch.autograd.Function):
735766

@@ -1083,6 +1114,14 @@ def _wrapper_npu_stride_add(self, other, offset1, offset2, c1_len):
10831114
return _NPUStrideAddOP.apply(self, other, offset1, offset2, c1_len)
10841115

10851116

1117+
def _wrapper_npu_dynamic_quant(input_dummy, smooth_scales=None):
    """Route torch_npu.npu_dynamic_quant through the ONNX-exportable autograd op."""
    return _NPUDynamicQuantOp.apply(input_dummy, smooth_scales)
1119+
1120+
1121+
def _wrapper_npu_dynamic_quant_asymmetric(input_dummy, smooth_scales=None, group_index=None, dst_type=torch.int8):
    """Route torch_npu.npu_dynamic_quant_asymmetric through the ONNX-exportable autograd op."""
    return _NPUDynamicQuantV2Op.apply(input_dummy, smooth_scales, group_index, dst_type)
1123+
1124+
10861125
def _wrapper_npu_gru(inputs, hx, weight_input, weight_hidden, bias_input, bias_hidden,
10871126
seq_length, has_biases, num_layers, dropout, train, bidirectional, batch_first):
10881127
return _NPUGruOP.apply(inputs, hx, weight_input, weight_hidden, bias_input, bias_hidden,
@@ -1189,6 +1228,8 @@ def _add_onnx_ops():
11891228
torch_npu.npu_scatter = _wrapper_npu_scatter
11901229
torch_npu.npu_scatter_nd_update = _wrapper_npu_scatter_nd_update
11911230
torch_npu.npu_lstm = _wrapper_npu_lstm
1231+
torch_npu.npu_dynamic_quant = _wrapper_npu_dynamic_quant
1232+
torch_npu.npu_dynamic_quant_asymmetric = _wrapper_npu_dynamic_quant_asymmetric
11921233
torch_npu.npu_rms_norm = _wrapper_npu_rms_norm
11931234
torch_npu.npu_add_rms_norm = _wrapper_npu_add_rms_norm
11941235
torch_npu.npu_lstm_cell = _wrapper_npu_lstm_cell

0 commit comments

Comments
 (0)