diff --git a/ndsl/testing/comparison.py b/ndsl/testing/comparison.py index ab4d881d..2feea615 100644 --- a/ndsl/testing/comparison.py +++ b/ndsl/testing/comparison.py @@ -70,17 +70,19 @@ def _compute_errors( near_zero, ) -> npt.NDArray[np.bool_]: if self.references.dtype in (np.float64, np.int64, np.float32, np.int32): - denom = self.references - denom[self.references == 0] = self.computed[self.references == 0] + # Rule number 1: Never touch the reference data! + denom = self.references.copy() + # Avoid division by 0. If reference is 0, we expect the computed value to be 0 too. + # (abs(computed - reference) / 1.0) is a good value for the error in this case. + denom[self.references == 0] = 1.0 self._calculated_metric = np.asarray( np.abs((self.computed - self.references) / denom) ) - self._calculated_metric[denom == 0] = 0.0 elif self.references.dtype in (np.bool_, bool): self._calculated_metric = np.logical_xor(self.computed, self.references) else: raise TypeError( - f"received data with unexpected dtype {self.references.dtype}" + f"Received data with unexpected dtype `{self.references.dtype}`." ) success = np.logical_or( np.logical_and(np.isnan(self.computed), np.isnan(self.references)),