diff --git a/ndsl/testing/comparison.py b/ndsl/testing/comparison.py index b5c1ac0d..b04f09c0 100644 --- a/ndsl/testing/comparison.py +++ b/ndsl/testing/comparison.py @@ -27,8 +27,8 @@ def __init__( reference_values: np.ndarray, computed_values: np.ndarray, ): - self.references = np.atleast_1d(reference_values) - self.computed = np.atleast_1d(computed_values) + self.references = np.squeeze(np.atleast_1d(reference_values)) + self.computed = np.squeeze(np.atleast_1d(computed_values)) self.check = False @abstractmethod