Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ndsl/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from . import dsl # isort:skip
from .initialization.allocator import QuantityFactory # isort:skip
from .comm.communicator import CubedSphereCommunicator, TileCommunicator
from .comm.local_comm import LocalComm
from .comm.mpi import MPIComm
Expand All @@ -20,7 +21,6 @@
from .exceptions import OutOfBoundsError
from .halo.data_transformer import HaloExchangeSpec
from .halo.updater import HaloUpdater, HaloUpdateRequest, VectorInterfaceHaloUpdater
from .initialization.allocator import QuantityFactory
from .initialization.sizer import GridSizer, SubtileGridSizer
from .logging import ndsl_log
from .monitor.netcdf_monitor import NetCDFMonitor
Expand Down
2 changes: 1 addition & 1 deletion ndsl/dsl/dace/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def get_sdfg_path(
f"cannot be run with current resolution {config.tile_resolution}"
)

print(f"[DaCe Config] Rank {config.my_rank} loading SDFG {sdfg_dir_path}")
ndsl_log.debug(f"[DaCe Config] Rank {config.my_rank} loading SDFG {sdfg_dir_path}")

return sdfg_dir_path

Expand Down
147 changes: 134 additions & 13 deletions ndsl/quantity/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from mpi4py import MPI
from numpy.typing import ArrayLike

from ndsl.types import Number


if TYPE_CHECKING:
from ndsl import QuantityFactory
Expand Down Expand Up @@ -66,34 +68,129 @@ def _init_recursive(cls):
dict_of_quantities = _init_recursive(cls)
return dacite.from_dict(data_class=cls, data=dict_of_quantities)

class _FactorySwapDimensionsDefinitions:
"""INTERNAL: QuantityFactory carry a sizer which has a full definition of the dimensions.
It's this sizer that is leveraged for the factory to figure out allocations.
In a regular pattern of use, data dimensions fields tend to be _the exception_ rather
than the rule and therefor would need a Factory defined _for a few cases_.
We bring this tool to override temporarly the allocations based on a single descriptions of
the data dimenions at allocation time.
"""

def __init__(self, factory: QuantityFactory, ddims: dict[str, int]):
self._ddims = ddims
self._factory = factory

def __enter__(self):
self._original_dims = self._factory.sizer.extra_dim_lengths
self._factory.sizer.extra_dim_lengths = self._ddims

def __exit__(self, type, value, traceback):
self._factory.sizer.extra_dim_lengths = self._original_dims

@classmethod
def empty(cls, quantity_factory: QuantityFactory) -> Self:
"""Allocate all quantities"""
def empty(
cls,
quantity_factory: QuantityFactory,
*,
data_dimensions: dict[str, int] = {},
) -> Self:
"""Allocate all quantities. Do not expect 0 on values, values are random.

return cls._init(quantity_factory.empty)
Args:
quantity_factory: factory, expected to be defined on the Grid dimensions
e.g. without data dimensions.
data_dimensions: extra data dimensions required for any field with data dimensions.
Dict of name/size pair.
"""

with State._FactorySwapDimensionsDefinitions(quantity_factory, data_dimensions):
state = cls._init(quantity_factory.empty)
return state

@classmethod
def zeros(cls, quantity_factory: QuantityFactory) -> Self:
"""Allocate all quantities and fill their value to zeros"""
def zeros(
cls,
quantity_factory: QuantityFactory,
*,
data_dimensions: dict[str, int] = {},
) -> Self:
"""Allocate all quantities and fill their value to zeros

Args:
quantity_factory: factory, expected to be defined on the Grid dimensions
e.g. without data dimensions.
data_dimensions: extra data dimensions required for any field with data dimensions.
Dict of name/size pair.
"""

with State._FactorySwapDimensionsDefinitions(quantity_factory, data_dimensions):
state = cls._init(quantity_factory.zeros)
return state

@classmethod
def ones(
cls,
quantity_factory: QuantityFactory,
*,
data_dimensions: dict[str, int] = {},
) -> Self:
"""Allocate all quantities and fill their value to ones

return cls._init(quantity_factory.zeros)
Args:
quantity_factory: factory, expected to be defined on the Grid dimensions
e.g. without data dimensions.
data_dimensions: extra data dimensions required for any field with data dimensions.
Dict of name/size pair.
"""

with State._FactorySwapDimensionsDefinitions(quantity_factory, data_dimensions):
state = cls._init(quantity_factory.ones)
return state

@classmethod
def ones(cls, quantity_factory: QuantityFactory) -> Self:
"""Allocate all quantities and fill their value to ones"""
def full(
cls,
quantity_factory: QuantityFactory,
value: Number,
*,
data_dimensions: dict[str, int] = {},
) -> Self:
"""Allocate all quantities and fill them with the input value

return cls._init(quantity_factory.ones)
Args:
quantity_factory: factory, expected to be defined on the Grid dimensions
e.g. without data dimensions.
value: number to initialize the buffers with.
data_dimensions: extra data dimensions required for any field with data dimensions.
Dict of name/size pair.
"""

with State._FactorySwapDimensionsDefinitions(quantity_factory, data_dimensions):
state = cls._init(quantity_factory.empty)
state.fill(value)
return state

@classmethod
def copy_memory(
cls,
quantity_factory: QuantityFactory,
memory_map: StateMemoryMapping,
*,
data_dimensions: dict[str, int] = {},
) -> Self:
"""Allocate all quantities and fill their value based
on the given memory map. See `update_from_memory`"""
on the given memory map. See `update_from_memory`.

Args:
quantity_factory: factory, expected to be defined on the Grid dimensions
e.g. without data dimensions.
memory_map: Dict of name/buffer. See `update_from_memory`.
data_dimensions: extra data dimensions required for any field with data dimensions.
Dict of name/size pair.
"""

state = cls.zeros(quantity_factory)
state = cls.zeros(quantity_factory, data_dimensions=data_dimensions)
state.update_copy_memory(memory_map)

return state
Expand All @@ -104,19 +201,43 @@ def move_memory(
quantity_factory: QuantityFactory,
memory_map: StateMemoryMapping,
*,
data_dimensions: dict[str, int] = {},
check_shape_and_strides: bool = True,
) -> Self:
"""Allocate all quantities and move memory based on
on the given memory map. See `update_move_memory`."""
on the given memory map. See `update_move_memory`.

Args:
quantity_factory: factory, expected to be defined on the Grid dimensions
e.g. without data dimensions.
memory_map: Dict of name/buffer. See `update_from_memory`.
data_dimensions: extra data dimensions required for any field with data dimensions.
Dict of name/size pair.
check_shape_and_strides: Check for evey given buffer that the shape & strides match the
previously allocated memory.
"""

state = cls.zeros(quantity_factory)
state = cls.zeros(quantity_factory, data_dimensions=data_dimensions)
state.update_move_memory(
memory_map,
check_shape_and_strides=check_shape_and_strides,
)

return state

def fill(self, value: Number) -> None:
def _fill_recursive(
state: State,
value: Number,
) -> None:
for _field in dataclasses.fields(state):
if dataclasses.is_dataclass(_field.type):
_fill_recursive(state.__getattribute__(_field.name), value)
else:
state.__getattribute__(_field.name).field[:] = value

_fill_recursive(self, value)

def update_copy_memory(self, memory_map: dict[str, Any]) -> None:
"""Copy data into the Quantities carried by the state.

Expand Down
5 changes: 2 additions & 3 deletions ndsl/types.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import functools
from typing import Iterable, TypeVar
from typing import Iterable, TypeAlias

import numpy as np
from typing_extensions import Protocol


Array = TypeVar("Array")
Number: TypeAlias = int | float | np.int32 | np.int64 | np.float32 | np.float64


class Allocator(Protocol):
Expand All @@ -14,7 +14,6 @@ def __call__(self, shape: Iterable[int], dtype: type):


class NumpyModule(Protocol):

empty: Allocator
zeros: Allocator
ones: Allocator
Expand Down
80 changes: 78 additions & 2 deletions tests/quantity/test_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,17 +47,23 @@ class InnerB:
)


def test_state(tmpdir):
def test_basic_state(tmpdir):
_, quantity_factory = get_factories_single_tile(
5, 5, 3, 0, backend="dace:cpu_kfirst"
)

microphys_state = CodeState.zeros(quantity_factory)
# Test allocator
microphys_state = CodeState.ones(quantity_factory)
assert (microphys_state.inner_A.A.field[:] == 1.0).all()

# Test NetCDF round trip
microphys_state.inner_A.A.field[:] = 42.42
microphys_state.to_netcdf(Path(tmpdir))
microphys_state2 = CodeState.zeros(quantity_factory)
microphys_state2.update_from_netcdf(Path(tmpdir))
assert (microphys_state2.inner_A.A.field[:] == 42.42).all()

# Test memory move (no copy)
a = np.ones((5, 5, 3))
b = np.ones((5, 5, 3))
c = np.ones((5, 5, 3))
Expand All @@ -68,3 +74,73 @@ def test_state(tmpdir):
)
assert (microphys_state2.inner_A.A.field[:] == 1.0).all()
assert (microphys_state2.inner_B.B.field[:] == 23.23).all()

# Test fill
microphys_state2.fill(18.18)
assert (microphys_state2.inner_A.A.field[:] == 18.18).all()
assert (microphys_state2.inner_B.B.field[:] == 18.18).all()

# Test full
microphys_state3 = CodeState.full(quantity_factory, 90.90)
assert (microphys_state3.inner_A.A.field[:] == 90.90).all()
assert (microphys_state3.inner_B.B.field[:] == 90.90).all()
assert (microphys_state3.C.field[:] == 90.90).all()


@dataclasses.dataclass
class CodeStateWithDDim(State):
@dataclasses.dataclass
class InnerA:
ddim_A: Quantity = dataclasses.field(
metadata={
"name": "A",
"dims": [X_DIM, Y_DIM, Z_DIM, "ExtraDim1"],
"units": "kg kg-1",
"intent": "?",
"dtype": Float,
}
)

@dataclasses.dataclass
class InnerB:
ddim_B: Quantity = dataclasses.field(
metadata={
"name": "A",
"dims": [X_DIM, Y_DIM, Z_DIM, "ExtraDim2"],
"units": "kg kg-1",
"intent": "?",
"dtype": Float,
}
)

inner_A: InnerA
inner_B: InnerB
no_ddim: Quantity = dataclasses.field(
metadata={
"name": "C",
"dims": [X_DIM, Y_DIM, Z_DIM],
"units": "kg kg-1",
"intent": "?",
"dtype": Float,
}
)


def test_state_ddim():
_, quantity_factory = get_factories_single_tile(
5, 5, 3, 0, backend="dace:cpu_kfirst"
)

# Test allocator
microphys_state = CodeStateWithDDim.ones(
quantity_factory,
data_dimensions={
"ExtraDim1": 3,
"ExtraDim2": 4,
},
)
assert (microphys_state.no_ddim.field[:] == 1.0).all()
assert microphys_state.inner_A.ddim_A.field.shape == (5, 5, 3, 3)
assert (microphys_state.inner_A.ddim_A.field[:] == 1.0).all()
assert microphys_state.inner_B.ddim_B.field.shape == (5, 5, 3, 4)
assert (microphys_state.inner_B.ddim_B.field[:] == 1.0).all()
Loading