Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 31 additions & 7 deletions ndsl/testing/comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,11 +260,17 @@ def __init__(
# We might have sliced outputs in the translate test. Rather than funnel the slicing
# all the way down, we only compare input vs reference when their shapes match and bail out otherwise
if input_values is not None and input_values.shape == reference_values.shape:
self.number_changing_values = (
(input_values != reference_values).flatten().shape[0]
)
self.number_changing_values = (input_values != reference_values).sum()
# column information is only relevant if data is three-dimensional
if len(input_values.shape) == 3:
Comment thread
FlorianDeconinck marked this conversation as resolved.
self.changing_column_map = (input_values != reference_values).any(
axis=2
)
else:
self.changing_column_map = None
else:
self.number_changing_values = None
self.changing_column_map = None

def _compute_all_metrics(
self,
Expand Down Expand Up @@ -338,20 +344,38 @@ def report(self, file_path: str | None = None) -> list[str]:
failed_indices = np.logical_not(self.success).nonzero()
# List all errors to terminal and file
bad_indices_count = len(failed_indices[0])
if self.changing_column_map is not None:
if self.success.ndim == 3:
bad_column_count = (
np.logical_not(self.success).any(axis=2) & self.changing_column_map
).sum()
total_column_count = self.changing_column_map.sum()
bad_column_pct = round(bad_column_count / total_column_count * 100, 2)
else:
bad_column_count = None
total_column_count = None
bad_column_pct = None
else:
bad_column_count = None
total_column_count = None
bad_column_pct = None
full_count = len(self.references.flatten())
failures_of_all_grid_points_pct = round(
100.0 * (bad_indices_count / full_count), 2
)
if self.number_changing_values is not None:
if (
self.number_changing_values is not None
and bad_indices_count is not None
and bad_column_count is not None
):
failures_of_changing_gridpoint_pct = round(
100.0 * (bad_indices_count / self.number_changing_values), 2
)
report_local_failures = f"Failures (changing grid points) ({bad_indices_count}/{self.number_changing_values}) ({failures_of_changing_gridpoint_pct}%)\n"
report_local_failures = f"Failures: (changing columns, chainging points, all points) | {bad_column_count}/{total_column_count} - {bad_column_pct}%, {bad_indices_count}/{self.number_changing_values} - {failures_of_changing_gridpoint_pct}%, {bad_indices_count}/{full_count} - {failures_of_all_grid_points_pct}%\n"
else:
report_local_failures = ""
report_local_failures = f"all grid points: {bad_indices_count}/{full_count} - {failures_of_all_grid_points_pct}%\n"
report = [
f"{report_local_failures}"
f"Failures (all grid points) ({bad_indices_count}/{full_count}) ({failures_of_all_grid_points_pct}%)\n",
f"Index Computed Reference "
f"{'🔶 ' if not self.absolute_eps.is_default else ''}Absolute E(<{self.absolute_eps.value:.2e}) "
f"{'🔶 ' if not self.relative_fraction.is_default else ''}Relative E(<{self.relative_fraction.value * 100:.2e}%) "
Expand Down