Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ndsl/stencils/testing/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -83,4 +83,4 @@ Upon failure, the test will drop a `netCDF` file in a `./.translate-errors` dire

## Environment variables

- `PACE_TEST_N_THRESHOLD_SAMPLES`: Upon failure the system will try to perturb the output in an attempt to check for numerical instability. This means re-running the test for N samples. Default is `10`, `0` or less turns this feature off.
- `NDSL_TEST_N_THRESHOLD_SAMPLES`: Upon failure the system will try to perturb the output in an attempt to check for numerical instability. This means re-running the test for N samples. Default is `0`, which turns this feature off.
14 changes: 6 additions & 8 deletions ndsl/stencils/testing/test_translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
OUTDIR = "./.translate-outputs"
GPU_MAX_ERR = 1e-10
GPU_NEAR_ZERO = 1e-15
N_THRESHOLD_SAMPLES = int(os.getenv("NDSL_TEST_N_THRESHOLD_SAMPLES", 0))


def platform():
Expand Down Expand Up @@ -89,7 +90,7 @@ def process_override(threshold_overrides, testobj, test_name, backend):
testobj.skip_test = bool(match["skip_test"])
elif len(matches) > 1:
raise Exception(
"misconfigured threshold overrides file, more than 1 specification for "
"Misconfigured threshold overrides file, more than 1 specification for "
+ test_name
+ " with backend="
+ backend
Expand All @@ -98,9 +99,6 @@ def process_override(threshold_overrides, testobj, test_name, backend):
)


N_THRESHOLD_SAMPLES = int(os.getenv("PACE_TEST_N_THRESHOLD_SAMPLES", 10))


def get_thresholds(testobj, input_data):
_get_thresholds(testobj.compute, input_data)

Expand Down Expand Up @@ -158,7 +156,7 @@ def test_sequential_savepoint(
):
if case.testobj is None:
pytest.xfail(
f"no translate object available for savepoint {case.savepoint_name}"
f"No translate object available for savepoint {case.savepoint_name}."
)
stencil_config = StencilConfig(
compilation_config=CompilationConfig(backend=backend),
Expand All @@ -178,7 +176,7 @@ def test_sequential_savepoint(
if case.testobj.skip_test:
return
if not case.exists:
pytest.skip(f"Data at rank {case.grid.rank} does not exists")
pytest.skip(f"Data at rank {case.grid.rank} does not exist.")
input_data = dataset_to_dict(case.ds_in)
input_names = (
case.testobj.serialnames(case.testobj.in_vars["data_vars"])
Expand All @@ -188,7 +186,7 @@ def test_sequential_savepoint(
input_data = {name: input_data[name] for name in input_names}
except KeyError as e:
raise KeyError(
f"Variable {e} was described in the translate test but cannot be found in the NetCDF"
f"Variable {e} was described in the translate test but cannot be found in the NetCDF."
)
original_input_data = copy.deepcopy(input_data)
# give the user a chance to load data from other savepoints to allow
Expand All @@ -208,7 +206,7 @@ def test_sequential_savepoint(
try:
ref_data = all_ref_data[varname]
except KeyError:
raise KeyError(f"Output {varname} couldn't be found in output data")
raise KeyError(f"Output {varname} couldn't be found in output data.")
if hasattr(case.testobj, "subset_output"):
ref_data = case.testobj.subset_output(varname, ref_data)
with subtests.test(varname=varname):
Expand Down
8 changes: 4 additions & 4 deletions ndsl/stencils/testing/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,14 +85,14 @@ def setup(self, inputs) -> None:

def compute_func(self, **inputs) -> Optional[dict[str, Any]]:
"""Compute function to transform the dictionary of `inputs`.
Must return a dictionnary of updated variables"""
Must return a dictionary of updated variables"""
raise NotImplementedError("Implement a child class compute method")

def compute(self, inputs) -> dict[str, Any]:
"""Transform inputs from NetCDF to gt4py.storagers, run compute_func then slice
"""Transform inputs from NetCDF to gt4py.storages, run compute_func then slice
the outputs based on specifications.

Return: Dictonnary of storages reshaped for comparison
Return: Dictionary of storages reshaped for comparison
"""
self.setup(inputs)
return self.slice_output(self.compute_from_storage(inputs))
Expand Down Expand Up @@ -201,7 +201,7 @@ def collect_start_indices(self, datashape, varinfo):
def make_storage_data_input_vars(
self, inputs, storage_vars=None, dict_4d=True
) -> None:
"""From a set of raw inputs (straight from NetCDF), use the `in_vars` dictionnary to update inputs to
"""From a set of raw inputs (straight from NetCDF), use the `in_vars` dictionary to update inputs to
their configured shape.

Return: None
Expand Down