Skip to content
38 changes: 38 additions & 0 deletions ndsl/initialization/allocator.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from ndsl.constants import SPATIAL_DIMS
from ndsl.dsl.typing import Float
from ndsl.initialization import GridSizer
from ndsl.logging import ndsl_log_on_rank_0
from ndsl.quantity import Quantity, QuantityHaloSpec


Expand Down Expand Up @@ -52,8 +53,45 @@ def set_extra_dim_lengths(self, **kwargs: Any) -> None:
"""
Set the length of extra (non-x/y/z) dimensions.
"""
ndsl_log_on_rank_0.warning(
"`QuantityFactory.set_extra_dim_lengths` is deprecated. "
"Use `add_data_dimensions` or `update_data_dimensions`.",
)
Comment thread
FlorianDeconinck marked this conversation as resolved.
Outdated
self.sizer.extra_dim_lengths.update(kwargs)

def update_data_dimensions(
self,
data_dimension_descriptions: dict[str, int],
) -> None:
"""
Update the length of extra (non-x/y/z) dimensions, unknown data dimensions
will be added, existing updated.
Comment thread
FlorianDeconinck marked this conversation as resolved.
Outdated

Args:
data_dimension_descriptions: Dict of name/length pairs
"""
self.sizer.extra_dim_lengths.update(data_dimension_descriptions)
Comment thread
FlorianDeconinck marked this conversation as resolved.
Outdated

def add_data_dimensions(
self,
data_dimension_descriptions: dict[str, int],
) -> None:
"""
Add new data (non-x/y/z) dimensions via a key-length pair. If the dimension
already exists, it will error out.

Args:
data_dimension_descriptions: Dict of name/length pairs
"""
for name in data_dimension_descriptions.keys():
if name in self.sizer.extra_dim_lengths.keys():
raise ValueError(
f"[NDSL] Data dimension {name} already exists! "
"Use `update_data_dimensions` if you meant to update the length."
)

self.sizer.extra_dim_lengths.update(data_dimension_descriptions)

@classmethod
def from_backend(cls, sizer: GridSizer, backend: str) -> QuantityFactory:
"""Initialize a QuantityFactory to use a specific gt4py backend.
Expand Down
8 changes: 4 additions & 4 deletions tests/test_4d_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ def __call__(self, q_in: FloatFieldTracer, q_out: FloatField):

def test_non_orchestrated_call() -> None:
stencil_factory, quantity_factory = get_factories_single_tile(24, 24, 91, 3)
quantity_factory.set_extra_dim_lengths(
**{
quantity_factory.add_data_dimensions(
{
TRACER_DIM: ntracers,
}
)
Expand All @@ -63,8 +63,8 @@ def test_orchestrated_call() -> None:
stencil_factory, quantity_factory = get_factories_single_tile_orchestrated(
24, 24, 91, 3
)
quantity_factory.set_extra_dim_lengths(
**{
quantity_factory.add_data_dimensions(
{
TRACER_DIM: ntracers,
}
)
Expand Down
14 changes: 14 additions & 0 deletions tests/test_dimension_sizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,3 +220,17 @@ def test_allocator_empty(sizer, dim_case, units, dtype):
assert quantity.origin == dim_case.origin
assert quantity.extent == dim_case.extent
assert quantity.data.shape == dim_case.shape


def test_allocator_data_dimensions_operations(sizer):
quantity_factory = QuantityFactory.from_backend(sizer, "numpy")
quantity_factory.add_data_dimensions({"D0": 11})
assert "D0" in quantity_factory.sizer.extra_dim_lengths.keys()
assert quantity_factory.sizer.extra_dim_lengths["D0"] == 11
quantity_factory.update_data_dimensions({"D0": 22})
assert quantity_factory.sizer.extra_dim_lengths["D0"] == 22
with pytest.raises(
ValueError,
match="Use `update_data_dimensions` if you meant to update the length.",
):
quantity_factory.add_data_dimensions({"D0": 33})