Skip to content

Commit

Permalink
Histogram calculation improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
oguzhanbsolak committed Oct 7, 2024
1 parent 6eea328 commit 68c625b
Showing 1 changed file with 12 additions and 8 deletions.
20 changes: 12 additions & 8 deletions ai8x.py
Original file line number Diff line number Diff line change
Expand Up @@ -606,9 +606,8 @@ def _merge_hist(module):
stacked_bins = torch.stack(bins_to_stack)
min_edge = stacked_bins.min()
max_edge = stacked_bins.max()
# 2048 is the number of bins
width = (max_edge - min_edge) / 2048
merged_bins = torch.arange(min_edge.item(), (max_edge+width).item(), width.item())
# 2048 is the number of bins and 2049 is the number of edges
merged_bins = torch.linspace(min_edge.item(), max_edge.item(), 2049)
merged_counts = None

for hist in module.hist:
Expand Down Expand Up @@ -666,8 +665,13 @@ def _remove_outliers(module, outlier_removal_z_score=8.0):
# Get mean and std of histogram
hist_count = module.hist[0]
hist_bins = module.hist[1]
mean = torch.sum(hist_count * hist_bins) / torch.sum(hist_count)
std = torch.sqrt(torch.sum(hist_count * (hist_bins - mean)**2) / torch.sum(hist_count))
hist_bins_middle = []
for i in range(len(hist_bins) - 1):
hist_bins_middle.append((hist_bins[i] + hist_bins[i+1])/2)
hist_bins_middle = torch.tensor(hist_bins_middle)
mean = torch.sum(hist_count[1:] * hist_bins_middle) / torch.sum(hist_count[1:])
std = torch.sqrt(torch.sum(hist_count[1:] * (hist_bins_middle - mean)**2)
/ torch.sum(hist_count[1:]))

# When activations are very small, std ends up being 0 due to rounding.
# In this case, we set std to a very small value to prevent zero element histogram.
Expand All @@ -676,10 +680,10 @@ def _remove_outliers(module, outlier_removal_z_score=8.0):
# Calculate bounds according to z-score
upper_bound = mean + outlier_removal_z_score * std
lower_bound = mean - outlier_removal_z_score * std

hist_bins_middle = torch.cat((torch.tensor([0]), hist_bins_middle))
# Remove outliers according to bounds
hist_count[hist_bins > upper_bound] = 0
hist_count[hist_bins < lower_bound] = 0
hist_count[hist_bins_middle > upper_bound] = 0
hist_count[hist_bins_middle < lower_bound] = 0
non_zero_bins = hist_count != 0
hist_count = hist_count[non_zero_bins]
hist_bins = hist_bins[non_zero_bins]
Expand Down

0 comments on commit 68c625b

Please sign in to comment.