diff --git a/ndsl/dsl/dace/orchestration.py b/ndsl/dsl/dace/orchestration.py index 86bb5a24..45f2627c 100644 --- a/ndsl/dsl/dace/orchestration.py +++ b/ndsl/dsl/dace/orchestration.py @@ -359,9 +359,10 @@ def _parse_sdfg( return None with DaCeProgress(config, f"Parse code of {dace_program.name} to SDFG"): + closure = dace_program.__sdfg_closure__() sdfg = dace_program.to_sdfg( *args, - **dace_program.__sdfg_closure__(), + **closure, **kwargs, save=False, simplify=False, diff --git a/ndsl/dsl/stencil.py b/ndsl/dsl/stencil.py index a12c91bf..69c36eac 100644 --- a/ndsl/dsl/stencil.py +++ b/ndsl/dsl/stencil.py @@ -28,6 +28,7 @@ from ndsl.logging import ndsl_log from ndsl.quantity import Quantity from ndsl.quantity.field_bundle import FieldBundleType, MarkupFieldBundleType +from ndsl.quantity.tracer_bundle_type import MarkupTracerBundleType from ndsl.testing.comparison import LegacyMetric @@ -334,6 +335,11 @@ def __init__( types.name, do_markup=False ) + if isinstance(types, MarkupTracerBundleType): + raise NotImplementedError( + "TracerBundle markup types can't be resolved yet." + ) + self.stencil_object = gtscript.stencil( definition=func, externals=externals, @@ -379,9 +385,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> None: ndsl_log.debug(f"Running {self._func_name}") # Marshal arguments - args_list = list(args) - _convert_quantities_to_storage(args_list, kwargs) - args = tuple(args_list) + args, kwargs = _convert_NDSL_concepts_to_storage(args, kwargs) args_as_kwargs = dict(zip(self._argument_names, args)) # Ranks comparison tool @@ -524,25 +528,36 @@ def closure_resolver(self, constant_args, given_args, parent_closure=None): # t ) -def _convert_quantities_to_storage(args, kwargs): # type: ignore[no-untyped-def] - for i, arg in enumerate(args): - try: - # Check that 'dims' is an attribute of arg. If so, - # this means it's a Quantity, so we need - # to pull off the ndarray. - arg.dims - args[i] = arg.data - except AttributeError: - pass - for name, arg in kwargs.items(): - try: - # Check that 'dims' is an attribute of arg. If so, - # this means it's a Quantity, so we need - # to pull off the ndarray. - arg.dims - kwargs[name] = arg.data - except AttributeError: - pass +def _convert_NDSL_concepts_to_storage(args: tuple, kwargs: dict) -> tuple[tuple, dict]: + """Go through the list of args and kwargs to replace NDSL concepts. + + This function replaces NDSL concepts (like a `Quantity`) in the argument list + with things that a GT4Py stencils understands (e.g. ndarrays). + """ + + # The `args` tuple is immutable (all tuples are), so let's build a temporary + # argument list, such that we can replace arguments in that list. + arg_list = list(args) + + for index, argument in enumerate(args): + # if isinstance(argument, TracerBundle): + # # Reduce the TracerBundle to a Quantity (which is handled below) + # arg_list[index] = argument.data.data + + if isinstance(argument, Quantity): + # For Quantities, we need to pass on the underlying ndarray + arg_list[index] = argument.data + + for name, argument in kwargs.items(): + # if isinstance(argument, TracerBundle): + # # Reduce the TracerBundle to a Quantity (which is handled below) + # kwargs[name] = argument.data + + if isinstance(argument, Quantity): + # For Quantities, we need to pass on the underlying ndarray + kwargs[name] = argument.data + + return (tuple(arg_list), kwargs) class GridIndexing: diff --git a/ndsl/quantity/quantity.py b/ndsl/quantity/quantity.py index bffaf884..824150ee 100644 --- a/ndsl/quantity/quantity.py +++ b/ndsl/quantity/quantity.py @@ -37,6 +37,7 @@ def __init__( gt4py_backend: str | None = None, allow_mismatch_float_precision: bool = False, number_of_halo_points: int = 0, + raise_on_data_copy: bool = False, ): """Initialize a Quantity. @@ -54,6 +55,7 @@ def __init__( allow_mismatch_float_precision (bool, optional): allow for precision that is not the simulation-wide default configuration. Defaults to False. number_of_halo_points (int, optional): Number of halo points used. Defaults to 0. + raise_on_data_copy: raise if `data` is copied into this quantity Raises: ValueError: Data-type mismatch between configuration and input-data @@ -68,10 +70,13 @@ def __init__( f"Floating-point data type mismatch, asked for {data.dtype}, " f"Pace configured for {Float}" ) - if origin is None: - origin = (0,) * len(dims) # default origin at origin of array - else: - origin = tuple(origin) + + # Track if we copied data and if so, raise at the end in case + # `raise_on_data_copy` is True. The initialization errors on the safe side. + did_copy_data: bool = True + + # default origin at origin of array + origin = (0,) * len(dims) if origin is None else tuple(origin) if extent is None: extent = tuple(length - start for length, start in zip(data.shape, origin)) @@ -104,19 +109,20 @@ def __init__( ] ) - self._data = ( - data - if is_optimal_layout(data, dimensions) - else self._initialize_data( + if is_optimal_layout(data, dimensions): + self._data = data + did_copy_data = False + else: + self._data = self._initialize_data( data, origin=origin, gt4py_backend=gt4py_backend, dimensions=dimensions, ) - ) else: # We have no info about the gt4py_backend, so just assign it. self._data = data + did_copy_data = False _validate_quantity_property_lengths(data.shape, dims, origin, extent) self._metadata = QuantityMetadata( @@ -134,6 +140,11 @@ def __init__( self.data, self.dims, self.origin, self.extent ) + if raise_on_data_copy and did_copy_data: + raise RuntimeError( + "Data was copied into this quantity despite `raise_on_copy_data=True`." + ) + @classmethod def from_data_array( cls, diff --git a/ndsl/quantity/tracer_bundle.py b/ndsl/quantity/tracer_bundle.py new file mode 100644 index 00000000..790a3889 --- /dev/null +++ b/ndsl/quantity/tracer_bundle.py @@ -0,0 +1,194 @@ +import copy +from enum import Enum, auto +from typing import Any + +from dace import SDFG, SDFGState +from dace.data import create_datadescriptor +from dace.frontend.common import op_repository as oprepo +from dace.frontend.python.newast import ProgramVisitor + +from ndsl.constants import X_DIM, Y_DIM, Z_DIM +from ndsl.initialization.allocator import Quantity, QuantityFactory +from ndsl.quantity.tracer_bundle_type import TracerBundleTypeRegistry + + +@oprepo.replaces_method("ndsl.quantity.tracer_bundle.TracerBundle", "size") +def _tracer_bundle_fill_tracer( + pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, *args: Any, **kwargs: Any +) -> None: + raise NotImplementedError("let's just see if we get here") + + +@oprepo.replaces_method("ndsl.quantity.TracerBundle", "size") +def _tracer_bundle_fill_tracer_2( + pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, *args: Any, **kwargs: Any +) -> None: + raise NotImplementedError("let's just see if we get here 2") + + +@oprepo.replaces_method("tracers", "size") +def _tracer_bundle_fill_tracer_3( + pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, *args: Any, **kwargs: Any +) -> None: + raise NotImplementedError("let's just see if we get here 3") + + +@oprepo.replaces("fill_tracer_by_name") +def _fill_tracer_by_name( + pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, *args: Any, **kwargs: Any +) -> None: + bundle = args[0] + tracer_name = args[1] + fill_value = args[2] + + array_name = f"bundle_{bundle.type_name}" + if array_name not in sdfg.arrays: + sdfg.arrays[array_name] = create_datadescriptor(bundle.data.data) + + # insert tasklet to assign the value + + # connect tasklet. add missing inputs if necessary + + raise NotImplementedError("let's see if we get here") + + +class Region(Enum): + compute_domain = auto() + + +class Tracer(Quantity): + """A Tracer is a specialized Quantity, grouped together in a TracerBundle.""" + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + + def fill(self, value: Any, *, restrict_to: Region | None = None) -> None: + if restrict_to is None: + super().data[:] = value + elif restrict_to is Region.compute_domain: + super().field[:] = value + else: + raise NotImplementedError(f"Unknown restriction {restrict_to}.") + + +_TracerName = str +_TracerIndex = int +_TracerMapping = dict[_TracerName, _TracerIndex] +_TracerDataMapping = dict[_TracerIndex, Tracer] + + +class TracerBundle: + """A TracerBundle groups a given set of named/nameless tracers into a single + four-dimensional Quantity. + + All tracers can be accessed by index, e.g. `tracer[1]`. Named tracers can be + accessed by name too, e.g. `tracer.vapor` assuming `vapor` is defined in the + `mapping` of names to tracer indices. `len(tracers)` returns the size of this + TracerBundle. + """ + + def __init__( + self, + *, + type_name: str, + quantity_factory: QuantityFactory, + mapping: _TracerMapping = {}, + unit: str = "g/kg", + ) -> None: + """ + Initialize a TracerBundle of a given size. + + Args: + type_name (str): name under which this bundle's type is registered. + quantity_factory: QuantityFactory to build tracers with. + mapping: Optional mapping of names to tracer ids, e.g. `{"vapor": 3}`. + unit: Optional unit of the tracers (one for all). + """ + types: Any = TracerBundleTypeRegistry.T(type_name, do_markup=False) + + size = types[0].data_dims[0] + factory = _tracer_quantity_factory(quantity_factory, size) + + # TODO: zeros() or empty()? should this be an option? + self.data = factory.zeros( + [X_DIM, Y_DIM, Z_DIM, "tracers"], dtype=types[0].dtype, units=unit + ) + self._size = size + self._name_mapping = mapping + self._data_mapping: _TracerDataMapping = {} + self.type_name = type_name + + def __len__(self) -> int: + """Number of tracers in this bundle.""" + return self._size + + def size(self) -> int: + return self._size + + def __getattr__(self, name: _TracerName) -> Tracer | None: + """Access tracers by name, e.g. `tracers.ice`.""" + + index = self._name_mapping.get(name, None) + if index is None: + # This replicates as close possible the default behavior of getattr + # without breaking orchestration + return None + + return self._by_index(index) + + def __getitem__(self, index: _TracerIndex) -> Tracer: + """Access tracers by index, e.g. `tracers[i]`.""" + return self._by_index(index) + + def _by_index(self, index: _TracerIndex) -> Tracer: + if index < 0 or index >= self._size: + # Note: it is important to raise an IndexError to support iterations of + # the form `for tracer in tracers`. + raise IndexError(f"You can only select tracers in range [0, {self._size}).") + + # Memoize tracers accessed such that we always return the same instance + # regardless of whether users access through __getattr__() or __getitem__(). + if index not in self._data_mapping: + self._data_mapping[index] = Tracer( + data=self.data.data[:, :, :, index], + dims=self.data.dims[:-1], + origin=self.data.origin[:-1], + extent=self.data.extent[:-1], + units=self.data.units, + # Ensure we never copy data into a tracer + raise_on_data_copy=True, + ) + + return self._data_mapping[index] + + def fill_tracer( + self, index: _TracerIndex, *, value: Any, compute_domain_only: bool = False + ) -> None: + if compute_domain_only: + self.data.field[:, :, :, index] = value + else: + self.data.data[:, :, :, index] = value + + def fill_tracer_by_name( + self, name: str, *, value: Any, compute_domain_only: bool = False + ) -> None: + index = self._name_mapping[name] + + if compute_domain_only: + self.data.field[:, :, :, index] = value + else: + self.data.data[:, :, :, index] = value + + +def _tracer_quantity_factory( + quantity_factory: QuantityFactory, number_of_tracers: int +) -> QuantityFactory: + """Create a tracer factory from a given cartesian quantity factory. + + Args: + quantity_factory: Cartesian 3D factory to start from. + number_of_tracers: number of tracers in this bundle. + """ + tracer_factory = copy.copy(quantity_factory) + tracer_factory.set_extra_dim_lengths(tracers=number_of_tracers) + return tracer_factory diff --git a/ndsl/quantity/tracer_bundle_type.py b/ndsl/quantity/tracer_bundle_type.py new file mode 100644 index 00000000..f517a7bc --- /dev/null +++ b/ndsl/quantity/tracer_bundle_type.py @@ -0,0 +1,92 @@ +from dataclasses import dataclass +from typing import TypeAlias + +import dace +from gt4py.cartesian import gtscript + +from ndsl.dsl.typing import Float + + +BundleTypes: TypeAlias = tuple[gtscript._FieldDescriptor, dace.data.Structure] + + +@dataclass +class MarkupTracerBundleType: + """Markup a `TracerBundle` to delay specialization. + + Properties: + name: Name of the future type to retrieve from the registry. + """ + + name: str + + +class TracerBundleTypeRegistry: + """Class to register and retrieve TraceBundle types. + + Methods: + register: Register a type + T: access to any registered type for type hinting. + """ + + _type_registry: dict[str, BundleTypes] = {} + + @classmethod + def register(cls, name: str, *, size: int, dtype: type = Float) -> BundleTypes: + """Register a name type by name by giving the size of its data dimensions. + + The same type cannot be registered twice and will error out. + + Args: + name: Unique name for this `TracerBundle` type. + size: Number of tracers in the `TracerBundle`. + dtype: Data type of `TracerBundle`. + """ + if name in cls._type_registry: + # TODO: why do we do this? why not just warn out? + raise RuntimeError( + f"Names of `TracerBundle` types must be unique. `{name}` is already taken." + ) + + if " " in name: + raise ValueError("DaCe can't handle space in bundle names.") + + # TODO: do this properly + N = dace.symbol(f"{name}_item_size") + + cls._type_registry[name] = ( + gtscript.Field[gtscript.IJK, (dtype, ((size,)))], + dace.data.Structure( + members={ + # TODO: funnel dtype + # TODO: what to do about tracer size `N`? + "data": dace.data.Array(dace.float32, (size, N)), + "ice": dace.data.ArrayReference(dace.float32, (N,)), + "vapor": dace.data.ArrayReference(dace.float32, (N,)), + }, + name=name, + ), + ) + return cls._type_registry[name] + + @classmethod + def T( + cls, name: str, *, do_markup: bool = True + ) -> BundleTypes | MarkupTracerBundleType: + """ + Retrieve a previously registered type. + + Args: + name: name of the type as registered via `register` + do_markup: if name not registered, markup for a future specialization + at stencil call time + """ + if name not in cls._type_registry: + # Dev note: The markup feature is to allow early parsing (at file import) + # to go ahead - while we will resolve the full type when calling the stencil. + if do_markup: + return MarkupTracerBundleType(name) + + raise ValueError(f"TracerBundle type `{name}` has not been registered!") + + return cls._type_registry[name] diff --git a/tests/dsl/test_stencil_wrapper.py b/tests/dsl/test_stencil_wrapper.py index 5b8a49f8..42edf808 100644 --- a/tests/dsl/test_stencil_wrapper.py +++ b/tests/dsl/test_stencil_wrapper.py @@ -14,7 +14,7 @@ ) from ndsl.dsl.gt4py import PARALLEL, computation, interval from ndsl.dsl.gt4py_utils import make_storage_from_shape -from ndsl.dsl.stencil import _convert_quantities_to_storage +from ndsl.dsl.stencil import _convert_NDSL_concepts_to_storage from ndsl.dsl.typing import Float, FloatField @@ -326,7 +326,7 @@ def get_mock_quantity(): def test_convert_quantities_to_storage_no_args() -> None: args = [] kwargs = {} - _convert_quantities_to_storage(args, kwargs) + args, kwargs = _convert_NDSL_concepts_to_storage(args, kwargs) assert len(args) == 0 assert len(kwargs) == 0 @@ -335,7 +335,7 @@ def test_convert_quantities_to_storage_one_arg_quantity() -> None: quantity = get_mock_quantity() args = [quantity] kwargs = {} - _convert_quantities_to_storage(args, kwargs) + args, kwargs = _convert_NDSL_concepts_to_storage(args, kwargs) assert len(args) == 1 assert args[0] == quantity.data assert len(kwargs) == 0 @@ -345,7 +345,7 @@ def test_convert_quantities_to_storage_one_kwarg_quantity() -> None: quantity = get_mock_quantity() args = [] kwargs = {"val": quantity} - _convert_quantities_to_storage(args, kwargs) + args, kwargs = _convert_NDSL_concepts_to_storage(args, kwargs) assert len(args) == 0 assert len(kwargs) == 1 assert kwargs["val"] == quantity.data @@ -355,7 +355,7 @@ def test_convert_quantities_to_storage_one_arg_nonquantity() -> None: non_quantity = unittest.mock.MagicMock(spec=tuple) args = [non_quantity] kwargs = {} - _convert_quantities_to_storage(args, kwargs) + args, kwargs = _convert_NDSL_concepts_to_storage(args, kwargs) assert len(args) == 1 assert args[0] == non_quantity assert len(kwargs) == 0 @@ -365,7 +365,7 @@ def test_convert_quantities_to_storage_one_kwarg_non_quantity() -> None: non_quantity = unittest.mock.MagicMock(spec=tuple) args = [] kwargs = {"val": non_quantity} - _convert_quantities_to_storage(args, kwargs) + args, kwargs = _convert_NDSL_concepts_to_storage(args, kwargs) assert len(args) == 0 assert len(kwargs) == 1 assert kwargs["val"] == non_quantity diff --git a/tests/quantity/test_quantity.py b/tests/quantity/test_quantity.py index dccfa94f..7d1b9977 100644 --- a/tests/quantity/test_quantity.py +++ b/tests/quantity/test_quantity.py @@ -289,3 +289,14 @@ def test_data_setter(): # Expected fail: new array is not even an array with pytest.raises(TypeError, match="Quantity.data buffer swap failed.*"): quantity.data = "meh" + + +def test_raise_on_data_copy_option(): + with pytest.raises(RuntimeError, match="Data was copied.*"): + Quantity( + np.random.randn(3, 2, 4), + dims=["dim1", "dim_2", "dims3"], + units="n/a", + gt4py_backend="dace:cpu", + raise_on_data_copy=True, + ) diff --git a/tests/quantity/test_tracer_bundle.py b/tests/quantity/test_tracer_bundle.py new file mode 100644 index 00000000..1f04fcd7 --- /dev/null +++ b/tests/quantity/test_tracer_bundle.py @@ -0,0 +1,105 @@ +"""This module includes unit tests for the `TracerBundle` class.""" + +import pytest + +from ndsl.boilerplate import get_factories_single_tile +from ndsl.quantity.tracer_bundle import Tracer, TracerBundle +from ndsl.quantity.tracer_bundle_type import TracerBundleTypeRegistry + + +_TRACER_BUNDLE_TYPENAME = "TracerBundleTypeUnitTests" +TracerBundleTypeRegistry.register(_TRACER_BUNDLE_TYPENAME, size=5) + + +def test_query_size_of_bundle_with_len() -> None: + _, quantity_factory = get_factories_single_tile(nx=2, ny=3, nz=4, nhalo=1) + tracers = TracerBundle( + type_name=_TRACER_BUNDLE_TYPENAME, quantity_factory=quantity_factory + ) + + assert len(tracers) == 5 + + +def test_access_tracer_by_name() -> None: + _, factory = get_factories_single_tile(nx=2, ny=3, nz=4, nhalo=1) + tracers = TracerBundle( + type_name=_TRACER_BUNDLE_TYPENAME, + quantity_factory=factory, + mapping={"ice": 3, "vapor": 1}, + ) + + assert isinstance(tracers.ice, Tracer) + assert isinstance(tracers.vapor, Tracer) + assert tracers.snow is None + + +def test_access_tracer_by_index() -> None: + _, factory = get_factories_single_tile(nx=2, ny=3, nz=4, nhalo=1) + tracers = TracerBundle( + type_name=_TRACER_BUNDLE_TYPENAME, + quantity_factory=factory, + mapping={"ice": 3, "vapor": 1}, + ) + + assert isinstance(tracers[0], Tracer) + + with pytest.raises(IndexError, match=".*select tracers in range.*"): + tracers[len(tracers)] + + with pytest.raises(IndexError, match=".*select tracers in range.*"): + tracers[-1] + + +def test_same_tracer_by_name_and_index() -> None: + _, factory = get_factories_single_tile(nx=2, ny=3, nz=4, nhalo=1) + tracers = TracerBundle( + type_name=_TRACER_BUNDLE_TYPENAME, + quantity_factory=factory, + mapping={"ice": 3, "vapor": 1}, + ) + + ice_tracer = tracers.ice + other_ice_tracer = tracers[3] + + assert ice_tracer is other_ice_tracer + + +def test_units_are_propagated_to_tracers() -> None: + _, factory = get_factories_single_tile(nx=2, ny=3, nz=4, nhalo=1) + unit = "u" + tracers = TracerBundle( + type_name=_TRACER_BUNDLE_TYPENAME, + quantity_factory=factory, + unit=unit, + mapping={"ice": 3, "vapor": 1}, + ) + + ice_tracer = tracers.ice + other_ice_tracer = tracers[3] + + assert ice_tracer.units == unit + assert other_ice_tracer.units == unit + + +def test_loop_over_all_tracers_index() -> None: + _, factory = get_factories_single_tile(nx=2, ny=3, nz=4, nhalo=1) + tracers = TracerBundle( + type_name=_TRACER_BUNDLE_TYPENAME, + quantity_factory=factory, + mapping={"ice": 3, "vapor": 1}, + ) + + for index in range(len(tracers)): + assert isinstance(tracers[index], Tracer) + + +def test_loop_over_all_tracers() -> None: + _, factory = get_factories_single_tile(nx=2, ny=3, nz=4, nhalo=1) + tracers = TracerBundle( + type_name=_TRACER_BUNDLE_TYPENAME, + quantity_factory=factory, + mapping={"ice": 3, "vapor": 1}, + ) + + for tracer in tracers: + assert isinstance(tracer, Tracer) diff --git a/tests/quantity/test_tracer_bundle_workflow.py b/tests/quantity/test_tracer_bundle_workflow.py new file mode 100644 index 00000000..95136d0b --- /dev/null +++ b/tests/quantity/test_tracer_bundle_workflow.py @@ -0,0 +1,240 @@ +"""This module includes integration tests for the `TracerBundle` class, +testing whole workflows.""" + +from typing import Any + +import pytest + +from ndsl import Quantity, StencilFactory, orchestrate +from ndsl.boilerplate import ( + get_factories_single_tile, + get_factories_single_tile_orchestrated, +) +from ndsl.constants import X_DIM, Y_DIM, Z_DIM +from ndsl.dsl.gt4py import PARALLEL, computation, interval +from ndsl.dsl.typing import FloatField +from ndsl.quantity.tracer_bundle import TracerBundle +from ndsl.quantity.tracer_bundle_type import TracerBundleTypeRegistry + + +# workflow + +# 1. register tracer bundle type from (name, size, dtype) +# 2. initialize tracer bundle from (type_name, quantity_factory, [mapping, unit]) +# - size can be derived from registered type (via type_name) +# - dtype can be derived from registered type (via type_name) + +_TRACER_BUNDLE_TYPENAME = "TracerBundleTypeWorkflowTests" +_TracerBundleStencilType, _TracerBundleDaCeType = TracerBundleTypeRegistry.register( + _TRACER_BUNDLE_TYPENAME, size=5 +) + + +def fill_tracer_by_name( + bundle: TracerBundle, name: str, value: Any, *, write_halo: bool = False +) -> None: + bundle.fill_tracer_by_name(name, value=value, compute_domain_only=not write_halo) + + +class IceTracerSetup: + def __init__(self, stencil_factory: StencilFactory): + orchestrate( + obj=self, + config=stencil_factory.config.dace_config, + dace_compiletime_args=["tracers"], + ) + + def __call__(self, tracers: TracerBundle) -> None: + bla = tracers.size() + + # tracers.ice.data[:] = 20 + bla + fill_tracer_by_name(tracers, "ice", 20 + bla, write_halo=True) + # tracers.ice.field[:] = 10 + fill_tracer_by_name(tracers, "ice", 10) + + +@pytest.mark.parametrize("backend", ("numpy", "dace:cpu")) +def test_stencil_ice_tracer_setup(backend) -> None: + domain = (2, 3, 4) + halo_size = 1 + + stencil_factory, quantity_factory = get_factories_single_tile( + domain[0], domain[1], domain[2], halo_size, backend=backend + ) + + tracers = TracerBundle( + type_name=_TRACER_BUNDLE_TYPENAME, + quantity_factory=quantity_factory, + mapping={"ice": 3, "vapor": 0}, + ) + + setup = IceTracerSetup(stencil_factory) + setup(tracers) + + assert (tracers.ice.data[0] == 25).all() # check a part of the halo + assert (tracers.ice.field[:] == 10).all() # check the compute domain + + +def test_orchestrated_ice_tracer_setup() -> None: + domain = (2, 3, 4) + halo_size = 1 + + stencil_factory, quantity_factory = get_factories_single_tile_orchestrated( + domain[0], + domain[1], + domain[2], + halo_size, + ) + + tracers = TracerBundle( + type_name=_TRACER_BUNDLE_TYPENAME, + quantity_factory=quantity_factory, + mapping={"ice": 3, "vapor": 0}, + ) + + setup = IceTracerSetup(stencil_factory) + setup(tracers) + + assert (tracers.ice.data[0] == 25).all() # check a part of the halo + assert (tracers.ice.field[:] == 10).all() # check the compute domain + + +def copy_into_tracer(in_field: FloatField, out_field: FloatField): + with computation(PARALLEL), interval(...): + out_field = in_field + + +class CopyIntoVaporTracer: + def __init__(self, stencil_factory: StencilFactory): + orchestrate( + obj=self, + config=stencil_factory.config.dace_config, + dace_compiletime_args=["tracers"], + ) + self._copy_into_tracer_stencil = stencil_factory.from_dims_halo( + func=copy_into_tracer, + compute_dims=[X_DIM, Y_DIM, Z_DIM], + ) + + def __call__(self, tracers: TracerBundle, vapor_field: Quantity): + tracers.vapor.data = 0 # __g_tracers_vapor_data + + self._copy_into_tracer_stencil(vapor_field, tracers.vapor) # __g_tracers_vapor + + +@pytest.mark.parametrize("backend", ("numpy", "dace:cpu")) +def test_stencil_copy_into_vapor_tracer(backend) -> None: + domain = (2, 3, 4) + halo_size = 1 + + stencil_factory, quantity_factory = get_factories_single_tile( + domain[0], domain[1], domain[2], halo_size, backend=backend + ) + + tracers = TracerBundle( + type_name=_TRACER_BUNDLE_TYPENAME, + quantity_factory=quantity_factory, + mapping={"ice": 3, "vapor": 0}, + ) + vapor_setup = CopyIntoVaporTracer(stencil_factory) + + field = quantity_factory.ones(dims=[X_DIM, Y_DIM, Z_DIM], units="n/a") + vapor_setup(tracers, field) + assert (tracers.vapor.field[:] == 1).all() + + +def loop_over_tracers( + out_field: _TracerBundleStencilType, n_tracers: int, fill_value: int = 42 +): + with computation(PARALLEL), interval(...): + n = 0 + while n < n_tracers: + out_field[0, 0, 0][n] = fill_value + n = n + 1 + + +class ResetTracers: + def __init__(self, stencil_factory: StencilFactory): + orchestrate( + obj=self, + config=stencil_factory.config.dace_config, + dace_compiletime_args=["tracers"], + ) + self._loop_over_tracers_stencil = stencil_factory.from_dims_halo( + func=loop_over_tracers, + compute_dims=[X_DIM, Y_DIM, Z_DIM], + ) + + def __call__(self, tracers: TracerBundle, fill_value: int) -> None: + self._loop_over_tracers_stencil(tracers, len(tracers), fill_value) + + +@pytest.mark.parametrize( + "backend", + ( + "dace:cpu", + pytest.param( + "numpy", + marks=pytest.mark.xfail(reason="numpy backend cannot handle access of [n]"), + ), + ), +) +def test_stencil_reset_tracers(backend) -> None: + domain = (2, 3, 4) + halo_size = 1 + + stencil_factory, quantity_factory = get_factories_single_tile( + domain[0], domain[1], domain[2], halo_size, backend=backend + ) + + tracers = TracerBundle( + type_name=_TRACER_BUNDLE_TYPENAME, + quantity_factory=quantity_factory, + mapping={"ice": 3, "vapor": 0}, + ) + reset_tracers = ResetTracers(stencil_factory) + + fill_value = 42 + reset_tracers(tracers, fill_value) + for tracer in tracers: + assert (tracer.field[:] == fill_value).all() + + +class CopyIntoAllTracers: + def __init__(self, stencil_factory: StencilFactory): + orchestrate( + obj=self, + config=stencil_factory.config.dace_config, + dace_compiletime_args=["tracers"], + ) + self._copy_into_tracer_stencil = stencil_factory.from_dims_halo( + func=copy_into_tracer, + compute_dims=[X_DIM, Y_DIM, Z_DIM], + ) + + def __call__(self, tracers: TracerBundle, field: Quantity): + for tracer in tracers: + self._copy_into_tracer_stencil(field, tracer) + + +@pytest.mark.parametrize("backend", ("numpy", "dace:cpu")) +def test_stencil_copy_into_all_tracer(backend) -> None: + domain = (2, 3, 4) + halo_size = 1 + + stencil_factory, quantity_factory = get_factories_single_tile( + domain[0], domain[1], domain[2], halo_size, backend=backend + ) + + tracers = TracerBundle( + type_name=_TRACER_BUNDLE_TYPENAME, + quantity_factory=quantity_factory, + mapping={"ice": 3, "vapor": 0}, + ) + tracer_setup = CopyIntoAllTracers(stencil_factory) + + field = quantity_factory.ones(dims=[X_DIM, Y_DIM, Z_DIM], units="n/a") + tracer_setup(tracers, field) + + for tracer in tracers: + assert (tracer.field[:] == 1).all()