diff --git a/keras/preprocessing/image.py b/keras/preprocessing/image.py index c403d896b29..329f8151c68 100644 --- a/keras/preprocessing/image.py +++ b/keras/preprocessing/image.py @@ -748,9 +748,10 @@ def fit(self, x, if self.zca_whitening: flat_x = np.reshape(x, (x.shape[0], x.shape[1] * x.shape[2] * x.shape[3])) - sigma = np.dot(flat_x.T, flat_x) / flat_x.shape[0] - u, s, _ = linalg.svd(sigma) - self.principal_components = np.dot(np.dot(u, np.diag(1. / np.sqrt(s + self.zca_epsilon))), u.T) + num_examples = flat_x.shape[0] + u, s, vt = linalg.svd(flat_x / np.sqrt(num_examples)) + s_expand = np.hstack((s, np.zeros(vt.shape[0] - num_examples, dtype=flat_x.dtype))) + self.principal_components = (vt.T / np.sqrt(s_expand ** 2 + self.zca_epsilon)).dot(vt) class Iterator(Sequence):