From 6eb3627035d72c120dc92e8d3d887a98f7311155 Mon Sep 17 00:00:00 2001 From: Federico Carrara <74301866+federico-carrara@users.noreply.github.com> Date: Thu, 12 Dec 2024 11:56:41 +0100 Subject: [PATCH] refac: optimize `WelfordStatistics` update and finalization (#325) ### Description This PR replaces the current inefficient implementation of Welford statustics update and finalization with a faster numpy based implementation. - **What**: Refactored methods in `running_stats.py`. - **Why**: Performance. - **How**: Replacing for loops and list comprehensions with numpy vectorized ops. #### Notes No tests are there at the moment. However results provided by refactored code match the ones using the old version. --- **Please ensure your PR meets the following requirements:** - [x] Code builds and passes tests locally, including doctests - [ ] New tests have been added (for bug fixes/features) - [x] Pre-commit passes - [ ] PR to the documentation exists (for bug fixes / features) --- .../dataset/dataset_utils/running_stats.py | 45 +++++++++---------- 1 file changed, 22 insertions(+), 23 deletions(-) diff --git a/src/careamics/dataset/dataset_utils/running_stats.py b/src/careamics/dataset/dataset_utils/running_stats.py index 5ee40abd..23d20b31 100644 --- a/src/careamics/dataset/dataset_utils/running_stats.py +++ b/src/careamics/dataset/dataset_utils/running_stats.py @@ -34,36 +34,35 @@ def update_iterative_stats( Parameters ---------- count : NDArray - Number of elements in the array. + Number of elements in the array. Shape: (C,). mean : NDArray - Mean of the array. + Mean of the array. Shape: (C,). m2 : NDArray - Variance of the array. + Variance of the array. Shape: (C,). new_values : NDArray - New values to add to the mean and variance. + New values to add to the mean and variance. Shape: (C, 1, 1, Z, Y, X). Returns ------- tuple[NDArray, NDArray, NDArray] Updated count, mean, and variance. """ - count += np.array([np.prod(channel.shape) for channel in new_values]) - # newvalues - oldMean - delta = [ - np.subtract(v.flatten(), [m] * len(v.flatten())) - for v, m in zip(new_values, mean) - ] + num_channels = len(new_values) - mean += np.array([np.sum(d / c) for d, c in zip(delta, count)]) - # newvalues - newMeant - delta2 = [ - np.subtract(v.flatten(), [m] * len(v.flatten())) - for v, m in zip(new_values, mean) - ] + # --- update channel-wise counts --- + count += np.ones_like(count) * np.prod(new_values.shape[1:]) - m2 += np.array([np.sum(d * d2) for d, d2 in zip(delta, delta2)]) + # --- update channel-wise mean --- + # compute (new_values - old_mean) -> shape: (C, Z*Y*X) + delta = new_values.reshape(num_channels, -1) - mean.reshape(num_channels, 1) + mean += np.sum(delta / count.reshape(num_channels, 1), axis=1) - return (count, mean, m2) + # --- update channel-wise SoS --- + # compute (new_values - new_mean) -> shape: (C, Z*Y*X) + delta2 = new_values.reshape(num_channels, -1) - mean.reshape(num_channels, 1) + m2 += np.sum(delta * delta2, axis=1) + + return count, mean, m2 def finalize_iterative_stats( @@ -74,18 +73,18 @@ def finalize_iterative_stats( Parameters ---------- count : NDArray - Number of elements in the array. + Number of elements in the array. Shape: (C,). mean : NDArray - Mean of the array. + Mean of the array. Shape: (C,). m2 : NDArray - Variance of the array. + Variance of the array. Shape: (C,). Returns ------- tuple[NDArray, NDArray] - Final mean and standard deviation. + Final channel-wise mean and standard deviation. """ - std = np.array([np.sqrt(m / c) for m, c in zip(m2, count)]) + std = np.sqrt(m2 / count) if any(c < 2 for c in count): return np.full(mean.shape, np.nan), np.full(std.shape, np.nan) else: