@@ -730,6 +730,37 @@ def symbolic(g, x1: torch.Tensor, x2: torch.Tensor, hcom: str,
730
730
dequant_scale , pertoken_scale , comm_quant_scale_1 , comm_quant_scale_2 , antiquant_group_size , comm_turn )
731
731
732
732
733
class _NPUDynamicQuantOp(torch.autograd.Function):
    """Bridges ``torch.ops.npu.npu_dynamic_quant`` to a custom ONNX op.

    Eager execution delegates to the NPU kernel; ONNX export emits an
    ``npu::NPUDynamicQuant`` node with two outputs.
    """

    @staticmethod
    def forward(ctx, input_dummy, smooth_scales):
        # Eager path: hand straight off to the registered NPU operator.
        return torch.ops.npu.npu_dynamic_quant(input_dummy, smooth_scales=smooth_scales)

    @staticmethod
    def symbolic(g, input_dummy: Tensor, smooth_scales: Optional[Tensor] = None):
        # Export path: the custom op always takes two inputs, so a missing
        # smooth_scales is replaced by an empty constant tensor cast to the
        # input's dtype.
        if smooth_scales is None:
            placeholder = torch.tensor([]).to(input_dummy.type().dtype())
            smooth_scales = g.op("Constant", value_t=placeholder)
        return g.op("npu::NPUDynamicQuant", input_dummy, smooth_scales, outputs=2)
746
class _NPUDynamicQuantV2Op(torch.autograd.Function):
    """Bridges ``torch.ops.npu.npu_dynamic_quant_asymmetric`` to ONNX.

    Eager execution delegates to the NPU kernel; ONNX export emits an
    ``npu::NPUDynamicQuantV2`` node with three outputs.
    """

    @staticmethod
    def forward(ctx, input_dummy, smooth_scales, group_index, dst_type):
        # Eager path: forward every argument to the registered NPU operator.
        return torch.ops.npu.npu_dynamic_quant_asymmetric(
            input_dummy, smooth_scales=smooth_scales,
            group_index=group_index, dst_type=dst_type)

    @staticmethod
    def symbolic(g, input_dummy: Tensor, smooth_scales: Optional[Tensor] = None,
                 group_index: Optional[Tensor] = None, dst_type: torch.dtype = torch.int8):
        # Export path: absent optional tensors become empty constants so the
        # node always carries three inputs.
        if smooth_scales is None:
            scales_placeholder = torch.tensor([]).to(input_dummy.type().dtype())
            smooth_scales = g.op("Constant", value_t=scales_placeholder)
        if group_index is None:
            index_placeholder = torch.tensor([]).to(torch.int32)
            group_index = g.op("Constant", value_t=index_placeholder)
        dst_type_i = 2  # Currently only int8 is supported.
        return g.op("npu::NPUDynamicQuantV2", input_dummy, smooth_scales,
                    group_index, dst_type_i=dst_type_i, outputs=3)
733
764
734
765
class _NPUWeightQuantBatchMatmulOP (torch .autograd .Function ):
735
766
@@ -1083,6 +1114,14 @@ def _wrapper_npu_stride_add(self, other, offset1, offset2, c1_len):
1083
1114
return _NPUStrideAddOP .apply (self , other , offset1 , offset2 , c1_len )
1084
1115
1085
1116
1117
def _wrapper_npu_dynamic_quant(input_dummy, smooth_scales=None):
    """Dispatch ``npu_dynamic_quant`` through its autograd Function wrapper."""
    result = _NPUDynamicQuantOp.apply(input_dummy, smooth_scales)
    return result
1119
+
1120
+
1121
def _wrapper_npu_dynamic_quant_asymmetric(input_dummy, smooth_scales=None, group_index=None, dst_type=torch.int8):
    """Dispatch ``npu_dynamic_quant_asymmetric`` through its autograd Function wrapper."""
    result = _NPUDynamicQuantV2Op.apply(input_dummy, smooth_scales, group_index, dst_type)
    return result
1123
+
1124
+
1086
1125
def _wrapper_npu_gru (inputs , hx , weight_input , weight_hidden , bias_input , bias_hidden ,
1087
1126
seq_length , has_biases , num_layers , dropout , train , bidirectional , batch_first ):
1088
1127
return _NPUGruOP .apply (inputs , hx , weight_input , weight_hidden , bias_input , bias_hidden ,
@@ -1189,6 +1228,8 @@ def _add_onnx_ops():
1189
1228
torch_npu .npu_scatter = _wrapper_npu_scatter
1190
1229
torch_npu .npu_scatter_nd_update = _wrapper_npu_scatter_nd_update
1191
1230
torch_npu .npu_lstm = _wrapper_npu_lstm
1231
+ torch_npu .npu_dynamic_quant = _wrapper_npu_dynamic_quant
1232
+ torch_npu .npu_dynamic_quant_asymmetric = _wrapper_npu_dynamic_quant_asymmetric
1192
1233
torch_npu .npu_rms_norm = _wrapper_npu_rms_norm
1193
1234
torch_npu .npu_add_rms_norm = _wrapper_npu_add_rms_norm
1194
1235
torch_npu .npu_lstm_cell = _wrapper_npu_lstm_cell
0 commit comments