Change "threshold" name to "activation_threshold"
oguzhanbsolak committed Sep 25, 2024
1 parent fcbe1ff commit a06393e
Showing 1 changed file with 24 additions and 22 deletions.
ai8x.py: 24 additions & 22 deletions
@@ -692,16 +692,16 @@ def init_threshold_module(module, outlier_removal_z_score):
     """
     _merge_hist(module)
     _remove_outliers(module, outlier_removal_z_score)
-    module.threshold = nn.Parameter(module.hist[1].abs().max().log2().ceil().exp2(),
-                                    requires_grad=False)
+    module.activation_threshold = nn.Parameter(module.hist[1].abs().max().log2().ceil().exp2(),
+                                               requires_grad=False)
 
 
 def calc_threshold(module, iterations=5, bits=8):
     """
     Iteratively calculate threshold for activation quantization
     """
     e_min = torch.inf
-    t_nc = module.threshold
+    t_nc = module.activation_threshold
     t = None
 
     for i in range(iterations):
@@ -711,7 +711,7 @@ def calc_threshold(module, iterations=5, bits=8):
             e_min = e_i
             t = t_i
 
-    module.threshold = nn.Parameter(torch.log2(t), requires_grad=False)
+    module.activation_threshold = nn.Parameter(torch.log2(t), requires_grad=False)
 
 
 class QuantizationAwareModule(nn.Module):
@@ -761,7 +761,9 @@ def __init__(
         self.pooling = pooling
 
         self.output_shift = nn.Parameter(torch.tensor([0.]), requires_grad=False)
-        self.threshold = nn.Parameter(torch.tensor(0.), requires_grad=False)
+        # Activation threshold determined during QAT, used in quantization
+        # It determines the range of quantization
+        self.activation_threshold = nn.Parameter(torch.tensor(0.), requires_grad=False)
         self.final_scale = nn.Parameter(torch.tensor(0.), requires_grad=False)
 
         self.init_module(weight_bits, bias_bits, quantize_activation,
@@ -842,7 +844,7 @@ def forward(self, x):  # pylint: disable=arguments-differ
             # Quantized checkpoint will have subtracted threshold from output shift
             # Therefore, it shouldn't be done again in simulate mode
             if not dev.simulate:
-                out_shift = (out_shift - self.threshold).clamp(min=-15., max=15.)
+                out_shift = (out_shift - self.activation_threshold).clamp(min=-15., max=15.)
 
             out_scale = self.calc_out_scale(out_shift)
             x = self._conv_forward(  # pylint: disable=protected-access
@@ -1780,7 +1782,7 @@ class Eltwise(nn.Module):
     def __init__(self, f, clamp_activation=False):
         super().__init__()
         self.f = f
-        self.threshold = None
+        self.activation_threshold = None
         self.set_clamp(clamp_activation)
 
     def set_clamp(self, clamp_activation):
@@ -2186,7 +2188,7 @@ def apply_scales(model):
                 for name1, module1 in model.named_modules():
                     if isinstance(module1, Eltwise):
                         if adds[name] == name1:
-                            module.threshold = module1.threshold
+                            module.activation_threshold = module1.activation_threshold
                             break
 
     # Find the maximum threshold from the layers that are concatenated together
@@ -2195,16 +2197,16 @@ def apply_scales(model):
         if isinstance(module, QuantizationAwareModule):
             if name in concats:
                 if concat_thresholds.get(concats[name]) is None:
-                    concat_thresholds[concats[name]] = module.threshold
-                elif module.threshold > concat_thresholds[concats[name]]:
-                    concat_thresholds[concats[name]] = module.threshold
+                    concat_thresholds[concats[name]] = module.activation_threshold
+                elif module.activation_threshold > concat_thresholds[concats[name]]:
+                    concat_thresholds[concats[name]] = module.activation_threshold
 
     # Apply the maximum threshold to the layers that are concatenated together
     for name, module in model.named_modules():
         if isinstance(module, QuantizationAwareModule):
             if name in concats:
-                module.threshold = nn.Parameter(concat_thresholds[concats[name]],
-                                                requires_grad=False)
+                module.activation_threshold = nn.Parameter(concat_thresholds[concats[name]],
+                                                           requires_grad=False)
 
     # Find weight sharing layers and apply the maximum threshold from the multiple passes
     shared_threshold = {}
@@ -2216,20 +2218,20 @@ def apply_scales(model):
                 if isinstance(module1, QuantizationAwareModule):
                     if prev == name1:
                         if shared_threshold.get(name) is None:
-                            shared_threshold[name] = module1.threshold
-                        elif module1.threshold > shared_threshold[name]:
-                            shared_threshold[name] = module1.threshold
+                            shared_threshold[name] = module1.activation_threshold
+                        elif module1.activation_threshold > shared_threshold[name]:
+                            shared_threshold[name] = module1.activation_threshold
         for prev in prevs[name]:
             for name1, module1 in model.named_modules():
                 if isinstance(module1, QuantizationAwareModule):
                     if prev == name1:
-                        module1.threshold = shared_threshold[name]
+                        module1.activation_threshold = shared_threshold[name]
 
     # Get the thresholds after overrides
     thresholds = {}
     for name, module in model.named_modules():
         if isinstance(module, QuantizationAwareModule):
-            thresholds[name] = module.threshold
+            thresholds[name] = module.activation_threshold
 
     # Adjust bias and threshold values according to the previous layers,
     # and set the final scale value for output layers
@@ -2244,12 +2246,12 @@
                 if module.op is not None and module.op.bias is not None:
                     module.op.bias = nn.Parameter(module.op.bias /
                                                   torch.exp2(thresholds[name1]))
-                module.threshold = nn.Parameter((module.threshold -
-                                                 thresholds[name1]),
-                                                requires_grad=False)
+                module.activation_threshold = \
+                    nn.Parameter((module.activation_threshold - thresholds[name1]),
+                                 requires_grad=False)
             if module.wide:
                 module.final_scale = nn.Parameter(thresholds[name] -
-                                                  module.threshold,
+                                                  module.activation_threshold,
                                                   requires_grad=False)
             else:
                 module.final_scale = nn.Parameter(thresholds[name],

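For reference, the sketch below is not part of the commit; it is a minimal standalone illustration of how the renamed activation_threshold parameter is derived and used, based only on the expressions visible in this diff: init_threshold_module() rounds the largest absolute activation statistic up to a power of two, calc_threshold() later stores the chosen threshold in the log2 domain, and forward() (outside simulate mode) subtracts it from the output shift and clamps to [-15, 15]. The act_values tensor is an illustrative stand-in for module.hist[1].

# Standalone sketch; values are illustrative, not taken from a real model.
import torch
from torch import nn

act_values = torch.tensor([0.3, -1.7, 2.9, -0.4])  # stand-in for module.hist[1]

# As in init_threshold_module(): round max |activation| up to the next power of two
activation_threshold = nn.Parameter(act_values.abs().max().log2().ceil().exp2(),
                                    requires_grad=False)
print(float(activation_threshold))  # 4.0; per the new comment, it determines the range of quantization

# As in calc_threshold(): the threshold is stored in the log2 domain
# (the initial value is reused here as a stand-in for the iteratively chosen t)
activation_threshold = nn.Parameter(torch.log2(activation_threshold), requires_grad=False)

# As in forward() outside simulate mode: fold the threshold into the output shift
out_shift = torch.tensor([0.])
out_shift = (out_shift - activation_threshold).clamp(min=-15., max=15.)
print(out_shift)  # tensor([-2.])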