Skip to content

Commit

Permalink
refac: optimize WelfordStatistics update and finalization (#325)
Browse files Browse the repository at this point in the history
### Description

This PR replaces the current inefficient implementation of the Welford
statistics update and finalization routines with a faster NumPy-based
implementation.

- **What**: Refactored methods in `running_stats.py`.
- **Why**: Performance.
- **How**: Replacing for loops and list comprehensions with numpy
vectorized ops.

#### Notes

There are no tests at the moment. However, the results produced by the
refactored code match those of the old version.

---

**Please ensure your PR meets the following requirements:**

- [x] Code builds and passes tests locally, including doctests
- [ ] New tests have been added (for bug fixes/features)
- [x] Pre-commit passes
- [ ] PR to the documentation exists (for bug fixes / features)
  • Loading branch information
federico-carrara authored Dec 12, 2024
1 parent bf1e48a commit 6eb3627
Showing 1 changed file with 22 additions and 23 deletions.
45 changes: 22 additions & 23 deletions src/careamics/dataset/dataset_utils/running_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,36 +34,35 @@ def update_iterative_stats(
Parameters
----------
count : NDArray
Number of elements in the array.
Number of elements in the array. Shape: (C,).
mean : NDArray
Mean of the array.
Mean of the array. Shape: (C,).
m2 : NDArray
Variance of the array.
Variance of the array. Shape: (C,).
new_values : NDArray
New values to add to the mean and variance.
New values to add to the mean and variance. Shape: (C, 1, 1, Z, Y, X).
Returns
-------
tuple[NDArray, NDArray, NDArray]
Updated count, mean, and variance.
"""
count += np.array([np.prod(channel.shape) for channel in new_values])
# newvalues - oldMean
delta = [
np.subtract(v.flatten(), [m] * len(v.flatten()))
for v, m in zip(new_values, mean)
]
num_channels = len(new_values)

mean += np.array([np.sum(d / c) for d, c in zip(delta, count)])
# newvalues - newMeant
delta2 = [
np.subtract(v.flatten(), [m] * len(v.flatten()))
for v, m in zip(new_values, mean)
]
# --- update channel-wise counts ---
count += np.ones_like(count) * np.prod(new_values.shape[1:])

m2 += np.array([np.sum(d * d2) for d, d2 in zip(delta, delta2)])
# --- update channel-wise mean ---
# compute (new_values - old_mean) -> shape: (C, Z*Y*X)
delta = new_values.reshape(num_channels, -1) - mean.reshape(num_channels, 1)
mean += np.sum(delta / count.reshape(num_channels, 1), axis=1)

return (count, mean, m2)
# --- update channel-wise SoS ---
# compute (new_values - new_mean) -> shape: (C, Z*Y*X)
delta2 = new_values.reshape(num_channels, -1) - mean.reshape(num_channels, 1)
m2 += np.sum(delta * delta2, axis=1)

return count, mean, m2


def finalize_iterative_stats(
Expand All @@ -74,18 +73,18 @@ def finalize_iterative_stats(
Parameters
----------
count : NDArray
Number of elements in the array.
Number of elements in the array. Shape: (C,).
mean : NDArray
Mean of the array.
Mean of the array. Shape: (C,).
m2 : NDArray
Variance of the array.
Variance of the array. Shape: (C,).
Returns
-------
tuple[NDArray, NDArray]
Final mean and standard deviation.
Final channel-wise mean and standard deviation.
"""
std = np.array([np.sqrt(m / c) for m, c in zip(m2, count)])
std = np.sqrt(m2 / count)
if any(c < 2 for c in count):
return np.full(mean.shape, np.nan), np.full(std.shape, np.nan)
else:
Expand Down

0 comments on commit 6eb3627

Please sign in to comment.