diff --git a/ndsl/__init__.py b/ndsl/__init__.py index 675610e3..6771b99a 100644 --- a/ndsl/__init__.py +++ b/ndsl/__init__.py @@ -1,4 +1,5 @@ from . import dsl # isort:skip +from .initialization.allocator import QuantityFactory # isort:skip from .comm.communicator import CubedSphereCommunicator, TileCommunicator from .comm.local_comm import LocalComm from .comm.mpi import MPIComm @@ -20,7 +21,6 @@ from .exceptions import OutOfBoundsError from .halo.data_transformer import HaloExchangeSpec from .halo.updater import HaloUpdater, HaloUpdateRequest, VectorInterfaceHaloUpdater -from .initialization.allocator import QuantityFactory from .initialization.sizer import GridSizer, SubtileGridSizer from .logging import ndsl_log from .monitor.netcdf_monitor import NetCDFMonitor diff --git a/ndsl/dsl/dace/build.py b/ndsl/dsl/dace/build.py index ba256e38..32e746ea 100644 --- a/ndsl/dsl/dace/build.py +++ b/ndsl/dsl/dace/build.py @@ -103,7 +103,7 @@ def get_sdfg_path( f"cannot be run with current resolution {config.tile_resolution}" ) - print(f"[DaCe Config] Rank {config.my_rank} loading SDFG {sdfg_dir_path}") + ndsl_log.debug(f"[DaCe Config] Rank {config.my_rank} loading SDFG {sdfg_dir_path}") return sdfg_dir_path diff --git a/ndsl/quantity/state.py b/ndsl/quantity/state.py index 4e920664..5ffbc702 100644 --- a/ndsl/quantity/state.py +++ b/ndsl/quantity/state.py @@ -9,6 +9,8 @@ from mpi4py import MPI from numpy.typing import ArrayLike +from ndsl.types import Number + if TYPE_CHECKING: from ndsl import QuantityFactory @@ -66,34 +68,129 @@ def _init_recursive(cls): dict_of_quantities = _init_recursive(cls) return dacite.from_dict(data_class=cls, data=dict_of_quantities) + class _FactorySwapDimensionsDefinitions: + """INTERNAL: QuantityFactory carry a sizer which has a full definition of the dimensions. + It's this sizer that is leveraged for the factory to figure out allocations. + In a regular pattern of use, data dimensions fields tend to be _the exception_ rather + than the rule and therefor would need a Factory defined _for a few cases_. + We bring this tool to override temporarly the allocations based on a single descriptions of + the data dimenions at allocation time. + """ + + def __init__(self, factory: QuantityFactory, ddims: dict[str, int]): + self._ddims = ddims + self._factory = factory + + def __enter__(self): + self._original_dims = self._factory.sizer.extra_dim_lengths + self._factory.sizer.extra_dim_lengths = self._ddims + + def __exit__(self, type, value, traceback): + self._factory.sizer.extra_dim_lengths = self._original_dims + @classmethod - def empty(cls, quantity_factory: QuantityFactory) -> Self: - """Allocate all quantities""" + def empty( + cls, + quantity_factory: QuantityFactory, + *, + data_dimensions: dict[str, int] = {}, + ) -> Self: + """Allocate all quantities. Do not expect 0 on values, values are random. - return cls._init(quantity_factory.empty) + Args: + quantity_factory: factory, expected to be defined on the Grid dimensions + e.g. without data dimensions. + data_dimensions: extra data dimensions required for any field with data dimensions. + Dict of name/size pair. + """ + + with State._FactorySwapDimensionsDefinitions(quantity_factory, data_dimensions): + state = cls._init(quantity_factory.empty) + return state @classmethod - def zeros(cls, quantity_factory: QuantityFactory) -> Self: - """Allocate all quantities and fill their value to zeros""" + def zeros( + cls, + quantity_factory: QuantityFactory, + *, + data_dimensions: dict[str, int] = {}, + ) -> Self: + """Allocate all quantities and fill their value to zeros + + Args: + quantity_factory: factory, expected to be defined on the Grid dimensions + e.g. without data dimensions. + data_dimensions: extra data dimensions required for any field with data dimensions. + Dict of name/size pair. + """ + + with State._FactorySwapDimensionsDefinitions(quantity_factory, data_dimensions): + state = cls._init(quantity_factory.zeros) + return state + + @classmethod + def ones( + cls, + quantity_factory: QuantityFactory, + *, + data_dimensions: dict[str, int] = {}, + ) -> Self: + """Allocate all quantities and fill their value to ones - return cls._init(quantity_factory.zeros) + Args: + quantity_factory: factory, expected to be defined on the Grid dimensions + e.g. without data dimensions. + data_dimensions: extra data dimensions required for any field with data dimensions. + Dict of name/size pair. + """ + + with State._FactorySwapDimensionsDefinitions(quantity_factory, data_dimensions): + state = cls._init(quantity_factory.ones) + return state @classmethod - def ones(cls, quantity_factory: QuantityFactory) -> Self: - """Allocate all quantities and fill their value to ones""" + def full( + cls, + quantity_factory: QuantityFactory, + value: Number, + *, + data_dimensions: dict[str, int] = {}, + ) -> Self: + """Allocate all quantities and fill them with the input value - return cls._init(quantity_factory.ones) + Args: + quantity_factory: factory, expected to be defined on the Grid dimensions + e.g. without data dimensions. + value: number to initialize the buffers with. + data_dimensions: extra data dimensions required for any field with data dimensions. + Dict of name/size pair. + """ + + with State._FactorySwapDimensionsDefinitions(quantity_factory, data_dimensions): + state = cls._init(quantity_factory.empty) + state.fill(value) + return state @classmethod def copy_memory( cls, quantity_factory: QuantityFactory, memory_map: StateMemoryMapping, + *, + data_dimensions: dict[str, int] = {}, ) -> Self: """Allocate all quantities and fill their value based - on the given memory map. See `update_from_memory`""" + on the given memory map. See `update_from_memory`. + + Args: + quantity_factory: factory, expected to be defined on the Grid dimensions + e.g. without data dimensions. + memory_map: Dict of name/buffer. See `update_from_memory`. + data_dimensions: extra data dimensions required for any field with data dimensions. + Dict of name/size pair. + """ - state = cls.zeros(quantity_factory) + state = cls.zeros(quantity_factory, data_dimensions=data_dimensions) state.update_copy_memory(memory_map) return state @@ -104,12 +201,23 @@ def move_memory( quantity_factory: QuantityFactory, memory_map: StateMemoryMapping, *, + data_dimensions: dict[str, int] = {}, check_shape_and_strides: bool = True, ) -> Self: """Allocate all quantities and move memory based on - on the given memory map. See `update_move_memory`.""" + on the given memory map. See `update_move_memory`. + + Args: + quantity_factory: factory, expected to be defined on the Grid dimensions + e.g. without data dimensions. + memory_map: Dict of name/buffer. See `update_from_memory`. + data_dimensions: extra data dimensions required for any field with data dimensions. + Dict of name/size pair. + check_shape_and_strides: Check for evey given buffer that the shape & strides match the + previously allocated memory. + """ - state = cls.zeros(quantity_factory) + state = cls.zeros(quantity_factory, data_dimensions=data_dimensions) state.update_move_memory( memory_map, check_shape_and_strides=check_shape_and_strides, @@ -117,6 +225,19 @@ def move_memory( return state + def fill(self, value: Number) -> None: + def _fill_recursive( + state: State, + value: Number, + ) -> None: + for _field in dataclasses.fields(state): + if dataclasses.is_dataclass(_field.type): + _fill_recursive(state.__getattribute__(_field.name), value) + else: + state.__getattribute__(_field.name).field[:] = value + + _fill_recursive(self, value) + def update_copy_memory(self, memory_map: dict[str, Any]) -> None: """Copy data into the Quantities carried by the state. diff --git a/ndsl/types.py b/ndsl/types.py index e3461c39..10592e32 100644 --- a/ndsl/types.py +++ b/ndsl/types.py @@ -1,11 +1,11 @@ import functools -from typing import Iterable, TypeVar +from typing import Iterable, TypeAlias import numpy as np from typing_extensions import Protocol -Array = TypeVar("Array") +Number: TypeAlias = int | float | np.int32 | np.int64 | np.float32 | np.float64 class Allocator(Protocol): @@ -14,7 +14,6 @@ def __call__(self, shape: Iterable[int], dtype: type): class NumpyModule(Protocol): - empty: Allocator zeros: Allocator ones: Allocator diff --git a/tests/quantity/test_state.py b/tests/quantity/test_state.py index 97ba6262..4c3cd1b5 100644 --- a/tests/quantity/test_state.py +++ b/tests/quantity/test_state.py @@ -47,17 +47,23 @@ class InnerB: ) -def test_state(tmpdir): +def test_basic_state(tmpdir): _, quantity_factory = get_factories_single_tile( 5, 5, 3, 0, backend="dace:cpu_kfirst" ) - microphys_state = CodeState.zeros(quantity_factory) + # Test allocator + microphys_state = CodeState.ones(quantity_factory) + assert (microphys_state.inner_A.A.field[:] == 1.0).all() + + # Test NetCDF round trip microphys_state.inner_A.A.field[:] = 42.42 microphys_state.to_netcdf(Path(tmpdir)) microphys_state2 = CodeState.zeros(quantity_factory) microphys_state2.update_from_netcdf(Path(tmpdir)) assert (microphys_state2.inner_A.A.field[:] == 42.42).all() + + # Test memory move (no copy) a = np.ones((5, 5, 3)) b = np.ones((5, 5, 3)) c = np.ones((5, 5, 3)) @@ -68,3 +74,73 @@ def test_state(tmpdir): ) assert (microphys_state2.inner_A.A.field[:] == 1.0).all() assert (microphys_state2.inner_B.B.field[:] == 23.23).all() + + # Test fill + microphys_state2.fill(18.18) + assert (microphys_state2.inner_A.A.field[:] == 18.18).all() + assert (microphys_state2.inner_B.B.field[:] == 18.18).all() + + # Test full + microphys_state3 = CodeState.full(quantity_factory, 90.90) + assert (microphys_state3.inner_A.A.field[:] == 90.90).all() + assert (microphys_state3.inner_B.B.field[:] == 90.90).all() + assert (microphys_state3.C.field[:] == 90.90).all() + + +@dataclasses.dataclass +class CodeStateWithDDim(State): + @dataclasses.dataclass + class InnerA: + ddim_A: Quantity = dataclasses.field( + metadata={ + "name": "A", + "dims": [X_DIM, Y_DIM, Z_DIM, "ExtraDim1"], + "units": "kg kg-1", + "intent": "?", + "dtype": Float, + } + ) + + @dataclasses.dataclass + class InnerB: + ddim_B: Quantity = dataclasses.field( + metadata={ + "name": "A", + "dims": [X_DIM, Y_DIM, Z_DIM, "ExtraDim2"], + "units": "kg kg-1", + "intent": "?", + "dtype": Float, + } + ) + + inner_A: InnerA + inner_B: InnerB + no_ddim: Quantity = dataclasses.field( + metadata={ + "name": "C", + "dims": [X_DIM, Y_DIM, Z_DIM], + "units": "kg kg-1", + "intent": "?", + "dtype": Float, + } + ) + + +def test_state_ddim(): + _, quantity_factory = get_factories_single_tile( + 5, 5, 3, 0, backend="dace:cpu_kfirst" + ) + + # Test allocator + microphys_state = CodeStateWithDDim.ones( + quantity_factory, + data_dimensions={ + "ExtraDim1": 3, + "ExtraDim2": 4, + }, + ) + assert (microphys_state.no_ddim.field[:] == 1.0).all() + assert microphys_state.inner_A.ddim_A.field.shape == (5, 5, 3, 3) + assert (microphys_state.inner_A.ddim_A.field[:] == 1.0).all() + assert microphys_state.inner_B.ddim_B.field.shape == (5, 5, 3, 4) + assert (microphys_state.inner_B.ddim_B.field[:] == 1.0).all()