diff --git a/ndsl/testing/comparison.py b/ndsl/testing/comparison.py index c1a5b1ef..b5c1ac0d 100644 --- a/ndsl/testing/comparison.py +++ b/ndsl/testing/comparison.py @@ -260,11 +260,17 @@ def __init__( # We might have sliced outputs in the translate test. Rather than funnel the slicing # all the way down, we bail out if we can measure input vs reference if input_values is not None and input_values.shape == reference_values.shape: - self.number_changing_values = ( - (input_values != reference_values).flatten().shape[0] - ) + self.number_changing_values = (input_values != reference_values).sum() + # column information is only relevant if data is three-dimensional + if len(input_values.shape) == 3: + self.changing_column_map = (input_values != reference_values).any( + axis=2 + ) + else: + self.changing_column_map = None else: self.number_changing_values = None + self.changing_column_map = None def _compute_all_metrics( self, @@ -338,20 +344,38 @@ def report(self, file_path: str | None = None) -> list[str]: failed_indices = np.logical_not(self.success).nonzero() # List all errors to terminal and file bad_indices_count = len(failed_indices[0]) + if self.changing_column_map is not None: + if self.success.ndim == 3: + bad_column_count = ( + np.logical_not(self.success).any(axis=2) & self.changing_column_map + ).sum() + total_column_count = self.changing_column_map.sum() + bad_column_pct = round(bad_column_count / total_column_count * 100, 2) + else: + bad_column_count = None + total_column_count = None + bad_column_pct = None + else: + bad_column_count = None + total_column_count = None + bad_column_pct = None full_count = len(self.references.flatten()) failures_of_all_grid_points_pct = round( 100.0 * (bad_indices_count / full_count), 2 ) - if self.number_changing_values is not None: + if ( + self.number_changing_values is not None + and bad_indices_count is not None + and bad_column_count is not None + ): failures_of_changing_gridpoint_pct = round( 100.0 * 
(bad_indices_count / self.number_changing_values), 2 ) - report_local_failures = f"Failures (changing grid points) ({bad_indices_count}/{self.number_changing_values}) ({failures_of_changing_gridpoint_pct}%)\n" + report_local_failures = f"Failures: (changing columns, changing points, all points) | {bad_column_count}/{total_column_count} - {bad_column_pct}%, {bad_indices_count}/{self.number_changing_values} - {failures_of_changing_gridpoint_pct}%, {bad_indices_count}/{full_count} - {failures_of_all_grid_points_pct}%\n" else: - report_local_failures = "" + report_local_failures = f"all grid points: {bad_indices_count}/{full_count} - {failures_of_all_grid_points_pct}%\n" report = [ f"{report_local_failures}" - f"Failures (all grid points) ({bad_indices_count}/{full_count}) ({failures_of_all_grid_points_pct}%)\n", f"Index Computed Reference " f"{'🔶 ' if not self.absolute_eps.is_default else ''}Absolute E(<{self.absolute_eps.value:.2e}) " f"{'🔶 ' if not self.relative_fraction.is_default else ''}Relative E(<{self.relative_fraction.value * 100:.2e}%) "