diff --git a/ai8x.py b/ai8x.py
index 5316c6c1e..3987f6629 100644
--- a/ai8x.py
+++ b/ai8x.py
@@ -692,8 +692,8 @@ def init_threshold_module(module, outlier_removal_z_score):
     """
     _merge_hist(module)
     _remove_outliers(module, outlier_removal_z_score)
-    module.threshold = nn.Parameter(module.hist[1].abs().max().log2().ceil().exp2(),
-                                    requires_grad=False)
+    module.activation_threshold = nn.Parameter(module.hist[1].abs().max().log2().ceil().exp2(),
+                                               requires_grad=False)
 
 
 def calc_threshold(module, iterations=5, bits=8):
@@ -701,7 +701,7 @@
     Iteratively calculate threshold for activation quantization
     """
     e_min = torch.inf
-    t_nc = module.threshold
+    t_nc = module.activation_threshold
     t = None
 
     for i in range(iterations):
@@ -711,7 +711,7 @@
             e_min = e_i
             t = t_i
 
-    module.threshold = nn.Parameter(torch.log2(t), requires_grad=False)
+    module.activation_threshold = nn.Parameter(torch.log2(t), requires_grad=False)
 
 
 class QuantizationAwareModule(nn.Module):
@@ -761,7 +761,9 @@ def __init__(
         self.pooling = pooling
 
         self.output_shift = nn.Parameter(torch.tensor([0.]), requires_grad=False)
-        self.threshold = nn.Parameter(torch.tensor(0.), requires_grad=False)
+        # Activation threshold determined during QAT, used in quantization
+        # It determines the range of quantization
+        self.activation_threshold = nn.Parameter(torch.tensor(0.), requires_grad=False)
         self.final_scale = nn.Parameter(torch.tensor(0.), requires_grad=False)
 
         self.init_module(weight_bits, bias_bits, quantize_activation,
@@ -842,7 +844,7 @@ def forward(self, x):  # pylint: disable=arguments-differ
             # Quantized checkpoint will have subtracted threshold from output shift
             # Therefore, it shouldn't be done again in simulate mode
             if not dev.simulate:
-                out_shift = (out_shift - self.threshold).clamp(min=-15., max=15.)
+                out_shift = (out_shift - self.activation_threshold).clamp(min=-15., max=15.)
             out_scale = self.calc_out_scale(out_shift)
 
             x = self._conv_forward(  # pylint: disable=protected-access
@@ -1780,7 +1782,7 @@ class Eltwise(nn.Module):
     def __init__(self, f, clamp_activation=False):
         super().__init__()
         self.f = f
-        self.threshold = None
+        self.activation_threshold = None
         self.set_clamp(clamp_activation)
 
     def set_clamp(self, clamp_activation):
@@ -2186,7 +2188,7 @@ def apply_scales(model):
                 for name1, module1 in model.named_modules():
                     if isinstance(module1, Eltwise):
                         if adds[name] == name1:
-                            module.threshold = module1.threshold
+                            module.activation_threshold = module1.activation_threshold
                             break
 
     # Find the maximum threshold from the layers that are concatenated together
@@ -2195,16 +2197,16 @@
         if isinstance(module, QuantizationAwareModule):
             if name in concats:
                 if concat_thresholds.get(concats[name]) is None:
-                    concat_thresholds[concats[name]] = module.threshold
-                elif module.threshold > concat_thresholds[concats[name]]:
-                    concat_thresholds[concats[name]] = module.threshold
+                    concat_thresholds[concats[name]] = module.activation_threshold
+                elif module.activation_threshold > concat_thresholds[concats[name]]:
+                    concat_thresholds[concats[name]] = module.activation_threshold
 
     # Apply the maximum threshold to the layers that are concatenated together
     for name, module in model.named_modules():
         if isinstance(module, QuantizationAwareModule):
             if name in concats:
-                module.threshold = nn.Parameter(concat_thresholds[concats[name]],
-                                                requires_grad=False)
+                module.activation_threshold = nn.Parameter(concat_thresholds[concats[name]],
+                                                           requires_grad=False)
 
     # Find weight sharing layers and apply the maximum threshold from the multiple passes
     shared_threshold = {}
@@ -2216,20 +2218,20 @@
                     if isinstance(module1, QuantizationAwareModule):
                         if prev == name1:
                             if shared_threshold.get(name) is None:
-                                shared_threshold[name] = module1.threshold
-                            elif module1.threshold > shared_threshold[name]:
-                                shared_threshold[name] = module1.threshold
+                                shared_threshold[name] = module1.activation_threshold
+                            elif module1.activation_threshold > shared_threshold[name]:
+                                shared_threshold[name] = module1.activation_threshold
             for prev in prevs[name]:
                 for name1, module1 in model.named_modules():
                     if isinstance(module1, QuantizationAwareModule):
                         if prev == name1:
-                            module1.threshold = shared_threshold[name]
+                            module1.activation_threshold = shared_threshold[name]
 
     # Get the thresholds after overrides
     thresholds = {}
     for name, module in model.named_modules():
         if isinstance(module, QuantizationAwareModule):
-            thresholds[name] = module.threshold
+            thresholds[name] = module.activation_threshold
 
     # Adjust bias and threshold values according to the previous layers,
     # and set the final scale value for output layers
@@ -2244,12 +2246,12 @@
                     if module.op is not None and module.op.bias is not None:
                         module.op.bias = nn.Parameter(module.op.bias /
                                                       torch.exp2(thresholds[name1]))
-                    module.threshold = nn.Parameter((module.threshold -
-                                                     thresholds[name1]),
-                                                    requires_grad=False)
+                    module.activation_threshold = \
+                        nn.Parameter((module.activation_threshold - thresholds[name1]),
+                                     requires_grad=False)
                     if module.wide:
                         module.final_scale = nn.Parameter(thresholds[name] -
-                                                          module.threshold,
+                                                          module.activation_threshold,
                                                           requires_grad=False)
                     else:
                         module.final_scale = nn.Parameter(thresholds[name],
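
The rename follows the quantization pattern already in the file: the activation threshold is the smallest power of two covering the observed activation range, it is stored as a log2 exponent by calc_threshold(), and forward() folds it into the clamped output shift. A minimal standalone sketch of that pattern in plain PyTorch, not the repository's API (the helper name power_of_two_threshold and the toy activations are assumptions for illustration):

import torch


def power_of_two_threshold(x):
    # Smallest power of two covering the observed range, mirroring
    # hist[1].abs().max().log2().ceil().exp2() in init_threshold_module()
    return x.abs().max().log2().ceil().exp2()


# Toy activations standing in for the merged histogram edges (illustration only)
acts = torch.randn(1000) * 3.0

t = power_of_two_threshold(acts)    # a power of two, e.g. tensor(16.)
log2_t = torch.log2(t)              # stored form, as in calc_threshold()

# Fold the threshold into the output shift, as forward() does when not dev.simulate
out_shift = torch.tensor([0.])
out_shift = (out_shift - log2_t).clamp(min=-15., max=15.)
out_scale = torch.exp2(out_shift)   # scale applied to the layer output

print(t.item(), log2_t.item(), out_shift.item(), out_scale.item())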