From 68c625bb1ab1053fc1f7b253f477683fe5471553 Mon Sep 17 00:00:00 2001 From: Oguzhan Buyuksolak Date: Mon, 7 Oct 2024 21:16:46 +0300 Subject: [PATCH] Histogram calculation improvements --- ai8x.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/ai8x.py b/ai8x.py index 3987f6629..efefa4dc9 100644 --- a/ai8x.py +++ b/ai8x.py @@ -606,9 +606,8 @@ def _merge_hist(module): stacked_bins = torch.stack(bins_to_stack) min_edge = stacked_bins.min() max_edge = stacked_bins.max() - # 2048 is the number of bins - width = (max_edge - min_edge) / 2048 - merged_bins = torch.arange(min_edge.item(), (max_edge+width).item(), width.item()) + # 2048 is the number of bins and 2049 is the number of edges + merged_bins = torch.linspace(min_edge.item(), max_edge.item(), 2049) merged_counts = None for hist in module.hist: @@ -666,8 +665,13 @@ def _remove_outliers(module, outlier_removal_z_score=8.0): # Get mean and std of histogram hist_count = module.hist[0] hist_bins = module.hist[1] - mean = torch.sum(hist_count * hist_bins) / torch.sum(hist_count) - std = torch.sqrt(torch.sum(hist_count * (hist_bins - mean)**2) / torch.sum(hist_count)) + hist_bins_middle = [] + for i in range(len(hist_bins) - 1): + hist_bins_middle.append((hist_bins[i] + hist_bins[i+1])/2) + hist_bins_middle = torch.tensor(hist_bins_middle) + mean = torch.sum(hist_count[1:] * hist_bins_middle) / torch.sum(hist_count[1:]) + std = torch.sqrt(torch.sum(hist_count[1:] * (hist_bins_middle - mean)**2) + / torch.sum(hist_count[1:])) # When activations are very small, std ends up being 0 due to rounding. # In this case, we set std to a very small value to prevent zero element histogram. @@ -676,10 +680,10 @@ def _remove_outliers(module, outlier_removal_z_score=8.0): # Calculate bounds according to z-score upper_bound = mean + outlier_removal_z_score * std lower_bound = mean - outlier_removal_z_score * std - + hist_bins_middle = torch.cat((torch.tensor([0]), hist_bins_middle)) # Remove outliers according to bounds - hist_count[hist_bins > upper_bound] = 0 - hist_count[hist_bins < lower_bound] = 0 + hist_count[hist_bins_middle > upper_bound] = 0 + hist_count[hist_bins_middle < lower_bound] = 0 non_zero_bins = hist_count != 0 hist_count = hist_count[non_zero_bins] hist_bins = hist_bins[non_zero_bins]