diff --git a/src/careamics/dataset/dataset_utils/running_stats.py b/src/careamics/dataset/dataset_utils/running_stats.py index 5ee40abd..23d20b31 100644 --- a/src/careamics/dataset/dataset_utils/running_stats.py +++ b/src/careamics/dataset/dataset_utils/running_stats.py @@ -34,36 +34,35 @@ def update_iterative_stats( Parameters ---------- count : NDArray - Number of elements in the array. + Number of elements in the array. Shape: (C,). mean : NDArray - Mean of the array. + Mean of the array. Shape: (C,). m2 : NDArray - Variance of the array. + Variance of the array. Shape: (C,). new_values : NDArray - New values to add to the mean and variance. + New values to add to the mean and variance. Shape: (C, 1, 1, Z, Y, X). Returns ------- tuple[NDArray, NDArray, NDArray] Updated count, mean, and variance. """ - count += np.array([np.prod(channel.shape) for channel in new_values]) - # newvalues - oldMean - delta = [ - np.subtract(v.flatten(), [m] * len(v.flatten())) - for v, m in zip(new_values, mean) - ] + num_channels = len(new_values) - mean += np.array([np.sum(d / c) for d, c in zip(delta, count)]) - # newvalues - newMeant - delta2 = [ - np.subtract(v.flatten(), [m] * len(v.flatten())) - for v, m in zip(new_values, mean) - ] + # --- update channel-wise counts --- + count += np.ones_like(count) * np.prod(new_values.shape[1:]) - m2 += np.array([np.sum(d * d2) for d, d2 in zip(delta, delta2)]) + # --- update channel-wise mean --- + # compute (new_values - old_mean) -> shape: (C, Z*Y*X) + delta = new_values.reshape(num_channels, -1) - mean.reshape(num_channels, 1) + mean += np.sum(delta / count.reshape(num_channels, 1), axis=1) - return (count, mean, m2) + # --- update channel-wise SoS --- + # compute (new_values - new_mean) -> shape: (C, Z*Y*X) + delta2 = new_values.reshape(num_channels, -1) - mean.reshape(num_channels, 1) + m2 += np.sum(delta * delta2, axis=1) + + return count, mean, m2 def finalize_iterative_stats( @@ -74,18 +73,18 @@ def finalize_iterative_stats( Parameters ---------- count : NDArray - Number of elements in the array. + Number of elements in the array. Shape: (C,). mean : NDArray - Mean of the array. + Mean of the array. Shape: (C,). m2 : NDArray - Variance of the array. + Variance of the array. Shape: (C,). Returns ------- tuple[NDArray, NDArray] - Final mean and standard deviation. + Final channel-wise mean and standard deviation. """ - std = np.array([np.sqrt(m / c) for m, c in zip(m2, count)]) + std = np.sqrt(m2 / count) if any(c < 2 for c in count): return np.full(mean.shape, np.nan), np.full(std.shape, np.nan) else: