diff --git a/ndsl/dsl/stencil.py b/ndsl/dsl/stencil.py index a12c91bf..57d6d738 100644 --- a/ndsl/dsl/stencil.py +++ b/ndsl/dsl/stencil.py @@ -3,6 +3,7 @@ import copy import dataclasses import inspect +import numbers from collections.abc import Callable, Iterable, Mapping, Sequence from typing import Any, cast @@ -378,6 +379,8 @@ def __call__(self, *args: Any, **kwargs: Any) -> None: if self.stencil_config.verbose: ndsl_log.debug(f"Running {self._func_name}") + self._validate_quantity_sizes(*args, **kwargs) + # Marshal arguments args_list = list(args) _convert_quantities_to_storage(args_list, kwargs) @@ -523,6 +526,39 @@ def closure_resolver(self, constant_args, given_args, parent_closure=None): # t constant_args, given_args, parent_closure=parent_closure ) + def _validate_quantity_sizes(self, *args, **kwargs): # type: ignore[no-untyped-def] + """Checks that the sizes of quantities are compatible with the domain of the stencil. + + This function emits a warning in case one of the dimensions does not match. + + """ + all_args_as_kwargs = dict(zip(self._argument_names, tuple(list(args)))) | kwargs + + domain_sizes = { + axis_name: axis_size + for axis_names, axis_size in zip([X_DIMS, Y_DIMS, Z_DIMS], self.domain) + for axis_name in axis_names + } + + for name, argument in all_args_as_kwargs.items(): + if isinstance(argument, Quantity): + for axis, quantity_size in zip(argument.dims, argument.extent): + full_size = quantity_size + if axis in (X_DIMS + Y_DIMS): + full_size += 2 * argument.metadata.n_halo + if ( + axis in (X_DIMS + Y_DIMS + Z_DIMS) + and full_size < domain_sizes[axis] + ): + ndsl_log.warning( + f"Quantity `{name}` is too small for the targeted " + f"domain in axis {axis}: {full_size} < {domain_sizes[axis]}." + ) + elif not isinstance(argument, numbers.Real): + ndsl_log.warning( + f"Found an array-type argument {name} that is not a Quantity. Some domain-size checks are omitted." + ) + def _convert_quantities_to_storage(args, kwargs): # type: ignore[no-untyped-def] for i, arg in enumerate(args): diff --git a/tests/dsl/test_stencil.py b/tests/dsl/test_stencil.py index 28db6a79..9b29fa5b 100644 --- a/tests/dsl/test_stencil.py +++ b/tests/dsl/test_stencil.py @@ -1,8 +1,19 @@ +from unittest.mock import MagicMock, patch + +import numpy as np import pytest from gt4py.storage import empty, ones -from ndsl import CompilationConfig, GridIndexing, StencilConfig, StencilFactory +from ndsl import ( + CompilationConfig, + FrozenStencil, + GridIndexing, + StencilConfig, + StencilFactory, +) from ndsl.dsl.gt4py import PARALLEL, Field, computation, interval +from ndsl.dsl.typing import FloatField +from ndsl.quantity import Quantity from tests.dsl import utils @@ -61,3 +72,43 @@ def test_grid_indexing_get_2d_compute_origin_domain( assert origin[2] == expected_origin_k assert domain[2] == 1 + + +def copy_stencil(q_in: FloatField, q_out: FloatField): # type: ignore + with computation(PARALLEL), interval(...): + q_out[0, 0, 0] = q_in + + +@pytest.mark.parametrize( + "extent,dimensions,domain,call_count", + [ + ((20, 20, 30), ["x", "y", "z"], (20, 20, 20), 0), + ((20, 20), ["x", "y"], (20, 20, 30), 0), + ((20, 20), ["x_interface", "y"], (20, 20, 30), 0), + ((20, 20), ["x", "y_interface"], (20, 20, 30), 0), + ((20,), ["z"], (20, 20, 10), 0), + ((20,), ["z_interface"], (20, 20, 10), 0), + ((15, 20, 30), ["x", "y", "z"], (20, 20, 30), 1), + ((20, 15, 30), ["x", "y", "z"], (20, 20, 30), 1), + ((20, 20, 15), ["x", "y", "z"], (20, 20, 30), 1), + ], +) +def test_domain_size_comparison( + extent: tuple[int], + dimensions: list[str], + domain: tuple[int], + call_count: int, +): + quantity = Quantity(np.zeros(extent), dimensions, "n/a", extent=extent) + stencil = FrozenStencil( + copy_stencil, + origin=(0, 0, 0), + domain=domain, + stencil_config=MagicMock(spec=StencilConfig()), + ) + # with expectation: + warning_mock = MagicMock() + with patch("ndsl.ndsl_log.warning", warning_mock): + stencil._validate_quantity_sizes(quantity) + + assert warning_mock.call_count == call_count