Skip to content
2 changes: 1 addition & 1 deletion examples/mpi/zarr_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@


def get_example_state(time):
sizer = SubtileGridSizer(nx=48, ny=48, nz=70, n_halo=3, extra_dim_lengths={})
sizer = SubtileGridSizer(nx=48, ny=48, nz=70, n_halo=3, data_dimensions={})
allocator = QuantityFactory(sizer, np)
air_temperature = allocator.zeros([X_DIM, Y_DIM, Z_DIM], units="degK")
air_temperature.view[:] = np.random.randn(*air_temperature.extent)
Expand Down
2 changes: 1 addition & 1 deletion ndsl/boilerplate.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def _get_factories(
ny_tile=ny,
nz=nz,
n_halo=nhalo,
extra_dim_lengths={},
data_dimensions={},
layout=partitioner.layout,
tile_partitioner=partitioner,
)
Expand Down
2 changes: 1 addition & 1 deletion ndsl/dsl/stencil.py
Original file line number Diff line number Diff line change
Expand Up @@ -593,7 +593,7 @@ def domain(self, domain: Index3D) -> None:
ny=domain[1],
nz=domain[2],
n_halo=self.n_halo,
extra_dim_lengths={},
data_dimensions={},
)

@classmethod
Expand Down
6 changes: 3 additions & 3 deletions ndsl/grid/generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,8 +242,8 @@ def __init__(
self._tile_partitioner = self._comm.tile.partitioner
self._rank = self._comm.rank
self.quantity_factory = quantity_factory
self.quantity_factory.set_extra_dim_lengths(
**{
self.quantity_factory.add_data_dimensions(
{
self.LON_OR_LAT_DIM: 2,
self.TILE_DIM: 6,
self.CARTESIAN_DIM: 3,
Expand Down Expand Up @@ -500,7 +500,7 @@ def from_tile_sizing(
ny_tile=npy - 1,
nz=npz,
n_halo=N_HALO_DEFAULT,
extra_dim_lengths={
data_dimensions={
cls.LON_OR_LAT_DIM: 2,
cls.TILE_DIM: 6,
cls.CARTESIAN_DIM: 3,
Expand Down
41 changes: 40 additions & 1 deletion ndsl/initialization/allocator.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,46 @@ def set_extra_dim_lengths(self, **kwargs: Any) -> None:
"""
Set the length of extra (non-x/y/z) dimensions.
"""
self.sizer.extra_dim_lengths.update(kwargs)
warnings.warn(
"`QuantityFactory.set_extra_dim_lengths` is deprecated. "
"Use `add_data_dimensions` or `update_data_dimensions`.",
DeprecationWarning,
2,
)
self.sizer.data_dimensions.update(kwargs)

def update_data_dimensions(
    self,
    data_dimension_descriptions: dict[str, int],
) -> None:
    """
    Set the length of data (non-x/y/z) dimensions on this factory's sizer.

    Names the sizer does not know yet are added; names that already exist
    get their length overwritten.

    Args:
        data_dimension_descriptions: Dict of name/length pairs
    """
    known_dimensions = self.sizer.data_dimensions
    # In-place merge: the sizer keeps the same dict object.
    known_dimensions.update(data_dimension_descriptions)

def add_data_dimensions(
    self,
    data_dimension_descriptions: dict[str, int],
) -> None:
    """
    Add new data (non-x/y/z) dimensions via name/length pairs.

    Every requested name is checked before anything is written, so a
    rejected call leaves the sizer's data dimensions untouched.

    Args:
        data_dimension_descriptions: Dict of name/length pairs

    Raises:
        ValueError: If one of the given names is already a data dimension
            of the sizer.
    """
    existing_dimensions = self.sizer.data_dimensions

    # Dicts iterate and test membership over their keys directly.
    for name in data_dimension_descriptions:
        if name in existing_dimensions:
            raise ValueError(
                f"[NDSL] Data dimension {name} already exists! "
                "Use `update_data_dimensions` if you meant to update the length."
            )

    existing_dimensions.update(data_dimension_descriptions)

@classmethod
def from_backend(cls, sizer: GridSizer, backend: str) -> QuantityFactory:
Expand Down
14 changes: 12 additions & 2 deletions ndsl/initialization/grid_sizer.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import warnings
from collections.abc import Sequence
from dataclasses import dataclass

Expand All @@ -12,8 +13,17 @@ class GridSizer:
"""Length of the z compute dimension for produced arrays."""
n_halo: int
"""Number of horizontal halo points for produced arrays."""
extra_dim_lengths: dict[str, int]
"""Lengths of any non-x/y/z dimensions, such as land or radiation dimensions."""
data_dimensions: dict[str, int]
"""Name/Lengths pair of any non-x/y/z dimensions, such as land or radiation dimensions."""

@property
def extra_dim_lengths(self) -> dict[str, int]:
    """Deprecated read access to `data_dimensions`; warns on every use."""
    deprecation_message = "`GridSizer.extra_dim_lengths` is a deprecated API, use `GridSizer.data_dimensions`."
    # stacklevel=2 attributes the warning to the code reading the property.
    warnings.warn(deprecation_message, DeprecationWarning, stacklevel=2)
    return self.data_dimensions

def get_origin(self, dims: Sequence[str]) -> tuple[int, ...]:
raise NotImplementedError()
Expand Down
37 changes: 24 additions & 13 deletions ndsl/initialization/subtile_grid_sizer.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import warnings
from collections.abc import Iterable
from typing import Self

Expand All @@ -15,10 +16,12 @@ def from_tile_params(
ny_tile: int,
nz: int,
n_halo: int,
extra_dim_lengths: dict[str, int],
layout: tuple[int, int],
*,
data_dimensions: dict[str, int] = {},

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is a Python anti-pattern to have an empty dict/list as a default argument. The reason is that this will couple function invocations, because that empty dict might not be empty anymore once we get here in the next invocation. People usually use `dict | None` as the type with `None` as the default, then test for `None` inside and assign a `{}` in that case (which is then local to that call).

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can also do a follow-up, but basically, as far as I can see, the last commit, bcea590, wouldn't be necessary if we had

        data_dimensions: dict[str, int] | None = None,

in the argument list and then do some more fiddling inside the function. This is surprising, especially coming from C++, where default arguments behave differently. Here's a simple example of what goes wrong.

I think there are mypy or flake8 rules to catch those. I remember seeing them (probably in gt4py). I'll search tomorrow.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah it's in flake8-bugbear

B006: Do not use mutable data structures for argument defaults. They are created during function definition time. All calls to the function reuse this one instance of that data structure, persisting changes between them.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

follow-up is here: #277

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another way to fix this is to actually use `dataclasses.field(default_factory=dict)` on the base dataclass.

tile_partitioner: TilePartitioner | None = None,
tile_rank: int = 0,
extra_dim_lengths: dict[str, int] | None = None,
) -> Self:
"""Create a SubtileGridSizer from parameters about the full tile.

Expand All @@ -27,13 +30,21 @@ def from_tile_params(
ny_tile: number of y cell centers on the tile
nz: number of vertical levels
n_halo: number of halo points
extra_dim_lengths: lengths of any non-x/y/z dimensions,
data_dimensions: lengths of any non-x/y/z dimensions,
such as land or radiation dimensions
layout: (y, x) number of ranks along tile edges
tile_partitioner (optional): partitioner object for the tile. By default, a
TilePartitioner is created with the given layout
tile_rank (optional): rank of this subtile.
extra_dim_lengths: DEPRECATED API - use `data_dimensions`
"""
if extra_dim_lengths is not None:
warnings.warn(
"`extra_dim_lengths` is a deprecated name, please `data_dimensions`.",
Comment thread
FlorianDeconinck marked this conversation as resolved.
Outdated
DeprecationWarning,
2,
)
data_dimensions = extra_dim_lengths
if tile_partitioner is None:
tile_partitioner = TilePartitioner(layout)
y_slice, x_slice = tile_partitioner.subtile_slice(
Expand All @@ -55,7 +66,7 @@ def from_tile_params(
"SubtileGridSizer::from_tile_params: Compute domain extent must be greater than halo size"
)

return cls(nx, ny, nz, n_halo, extra_dim_lengths)
return cls(nx, ny, nz, n_halo, data_dimensions)

@classmethod
def from_namelist(
Expand Down Expand Up @@ -92,19 +103,19 @@ def from_namelist(
"expected to find nx_tile or fv_core_nml"
)
return cls.from_tile_params(
nx_tile,
ny_tile,
nz,
N_HALO_DEFAULT,
{},
layout,
tile_partitioner,
tile_rank,
nx_tile=nx_tile,
ny_tile=ny_tile,
nz=nz,
n_halo=N_HALO_DEFAULT,
data_dimensions={},
layout=layout,
tile_partitioner=tile_partitioner,
tile_rank=tile_rank,
)

@property
def dim_extents(self) -> dict[str, int]:
return_dict = self.extra_dim_lengths.copy()
return_dict = self.data_dimensions.copy()
return_dict.update(
{
constants.X_DIM: self.nx,
Expand All @@ -128,7 +139,7 @@ def get_extent(self, dims: Iterable[str]) -> tuple[int, ...]:
return tuple(extents[dim] for dim in dims)

def get_shape(self, dims: Iterable[str]) -> tuple[int, ...]:
shape_dict = self.extra_dim_lengths.copy()
shape_dict = self.data_dimensions.copy()
# must pad non-interface variables to have the same shape as interface variables
shape_dict.update(
{
Expand Down
4 changes: 2 additions & 2 deletions ndsl/quantity/field_bundle.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,8 @@ def extend_3D_quantity_factory(
extra_dims: dict of [name, size] of the data dimensions to add.
"""
new_factory = copy.copy(quantity_factory)
new_factory.set_extra_dim_lengths(
**{
new_factory.add_data_dimensions(
{
**extra_dims,
}
)
Expand Down
6 changes: 3 additions & 3 deletions ndsl/quantity/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,16 +84,16 @@ def __init__(self, factory: QuantityFactory, ddims: dict[str, int]):
self._factory = factory

def __enter__(self) -> None:
self._original_dims = self._factory.sizer.extra_dim_lengths
self._factory.sizer.extra_dim_lengths = self._ddims
self._original_dims = self._factory.sizer.data_dimensions
self._factory.sizer.data_dimensions = self._ddims

def __exit__(
self,
type: type[BaseException] | None,
value: BaseException | None,
traceback: TracebackType | None,
) -> None:
self._factory.sizer.extra_dim_lengths = self._original_dims
self._factory.sizer.data_dimensions = self._original_dims

@classmethod
def empty(
Expand Down
2 changes: 1 addition & 1 deletion ndsl/stencils/testing/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def sizer(self):
ny_tile=self.npy - 1,
nz=self.npz,
n_halo=self.halo,
extra_dim_lengths={
data_dimensions={
MetricTerms.LON_OR_LAT_DIM: 2,
MetricTerms.TILE_DIM: 6,
MetricTerms.CARTESIAN_DIM: 3,
Expand Down
8 changes: 4 additions & 4 deletions tests/test_4d_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ def __call__(self, q_in: FloatFieldTracer, q_out: FloatField):

def test_non_orchestrated_call() -> None:
stencil_factory, quantity_factory = get_factories_single_tile(24, 24, 91, 3)
quantity_factory.set_extra_dim_lengths(
**{
quantity_factory.add_data_dimensions(
{
TRACER_DIM: ntracers,
}
)
Expand All @@ -63,8 +63,8 @@ def test_orchestrated_call() -> None:
stencil_factory, quantity_factory = get_factories_single_tile_orchestrated(
24, 24, 91, 3
)
quantity_factory.set_extra_dim_lengths(
**{
quantity_factory.add_data_dimensions(
{
TRACER_DIM: ntracers,
}
)
Expand Down
26 changes: 20 additions & 6 deletions tests/test_dimension_sizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,12 @@ def namelist(nx_tile, ny_tile, nz, layout):
def sizer(request, nx_tile, ny_tile, nz, layout, namelist, extra_dimension_lengths):
if request.param == "from_tile_params":
sizer = SubtileGridSizer.from_tile_params(
nx_tile,
ny_tile,
nz,
N_HALO_DEFAULT,
extra_dimension_lengths,
layout,
nx_tile=nx_tile,
ny_tile=ny_tile,
nz=nz,
n_halo=N_HALO_DEFAULT,
layout=layout,
data_dimensions=extra_dimension_lengths,
)
elif request.param == "from_namelist":
sizer = SubtileGridSizer.from_namelist(namelist)
Expand Down Expand Up @@ -220,3 +220,17 @@ def test_allocator_empty(sizer, dim_case, units, dtype):
assert quantity.origin == dim_case.origin
assert quantity.extent == dim_case.extent
assert quantity.data.shape == dim_case.shape


def test_allocator_data_dimensions_operations(sizer):
    """Exercise add, update and duplicate-add of data dimensions on a factory."""
    factory = QuantityFactory.from_backend(sizer, "numpy")

    # A brand-new dimension is registered with the requested length.
    factory.add_data_dimensions({"D0": 11})
    assert "D0" in factory.sizer.data_dimensions
    assert factory.sizer.data_dimensions["D0"] == 11

    # Updating an existing dimension overwrites its length.
    factory.update_data_dimensions({"D0": 22})
    assert factory.sizer.data_dimensions["D0"] == 22

    # Adding a name that already exists must be rejected.
    expected_error = "Use `update_data_dimensions` if you meant to update the length."
    with pytest.raises(ValueError, match=expected_error):
        factory.add_data_dimensions({"D0": 33})
Loading