Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ndsl/comm/communicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,7 @@ def gather_state(self, send_state=None, recv_state=None, transfer_type=None): #
dims=quantity.dims,
units=quantity.units,
allow_mismatch_float_precision=True,
gt4py_backend=quantity.metadata.gt4py_backend,
)
if recv_state is not None and name in recv_state:
tile_quantity = self.gather(
Expand Down
2 changes: 2 additions & 0 deletions ndsl/monitor/netcdf_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def __init__(self, initial: Quantity, time_chunk_size: int):
self._dims = initial.dims
self._units = initial.units
self._i_time = 1
self._backend = initial.metadata.gt4py_backend

def append(self, quantity: Quantity) -> None:
# Allow mismatch precision here since this is I/O
Expand All @@ -37,6 +38,7 @@ def data(self) -> Quantity:
dims=("time",) + tuple(self._dims),
units=self._units,
allow_mismatch_float_precision=True,
gt4py_backend=self._backend,
)


Expand Down
37 changes: 29 additions & 8 deletions ndsl/quantity/quantity.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def __init__(
dims (Sequence[str]): dimension names for each axis
units (str): units of the quantity
origin (Sequence[int] | None, optional): first point in data within the
computational domain. Defaults to None.
computational domain. Defaults to None.
extent (Sequence[int] | None, optional): number of points along each axis
within the computational domain. Defaults to None.
gt4py_backend (str | None, optional): backend to use for gt4py storages,
Expand All @@ -59,6 +59,13 @@ def __init__(
ValueError: Data-type mismatch between configuration and input-data
TypeError: Typing of the data that does not fit
"""
if gt4py_backend is None:
warnings.warn(
"gt4py_backend will be mandatory in future releases.",
DeprecationWarning,
stacklevel=2,
)

if (
not allow_mismatch_float_precision
and is_float(data.dtype)
Expand All @@ -84,9 +91,11 @@ def __init__(

if not isinstance(data, (np.ndarray, cupy.ndarray)):
raise TypeError(
f"Only supports numpy.ndarray and cupy.ndarray, got {type(data)}"
f"Only supports numpy.ndarray and cupy.ndarray, got {type(data)}."
)

_validate_quantity_property_lengths(data.shape, dims, origin, extent)

if gt4py_backend is not None:
gt4py_backend_cls = gt_backend.from_name(gt4py_backend)
is_optimal_layout = gt4py_backend_cls.storage_info["is_optimal_layout"]
Expand All @@ -104,21 +113,25 @@ def __init__(
]
)

self._data = (
data
if is_optimal_layout(data, dimensions)
else self._initialize_data(
# Assign data. Makes a copy if the layout isn't optimal for the given backend.
if is_optimal_layout(data, dimensions):
self._data = data
else:
warnings.warn(
f"Copying data to optimal layout for given backend {gt4py_backend}.",
UserWarning,
stacklevel=2,
)
self._data = self._initialize_data(
data,
origin=origin,
gt4py_backend=gt4py_backend,
dimensions=dimensions,
)
)
else:
# We have no info about the gt4py_backend, so just assign it.
self._data = data

_validate_quantity_property_lengths(data.shape, dims, origin, extent)
self._metadata = QuantityMetadata(
origin=_ensure_int_tuple(origin, "origin"),
extent=_ensure_int_tuple(extent, "extent"),
Expand Down Expand Up @@ -156,6 +169,14 @@ def from_data_array(
"""
if "units" not in data_array.attrs:
raise ValueError("need units attribute to create Quantity from DataArray")

if gt4py_backend is None:
warnings.warn(
"gt4py_backend will be mandatory in future releases.",
DeprecationWarning,
stacklevel=2,
)

return cls(
data_array.values,
cast(tuple[str], data_array.dims),
Expand Down
12 changes: 9 additions & 3 deletions tests/dsl/test_stencil.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,9 @@ def test_domain_size_comparison(
domain: tuple[int],
call_count: int,
):
quantity = Quantity(np.zeros(extent), dimensions, "n/a", extent=extent)
quantity = Quantity(
np.zeros(extent), dimensions, "n/a", extent=extent, gt4py_backend="debug"
)
stencil = FrozenStencil(
copy_stencil,
origin=(0, 0, 0),
Expand Down Expand Up @@ -147,7 +149,9 @@ def two_dim_temporaries_stencil(q_out: FloatField) -> None:

def test_stencil_2D_temporaries() -> None:
domain = (2, 2, 5)
quantity = Quantity(np.zeros(domain), ["x", "y", "z"], "n/a", extent=domain)
quantity = Quantity(
np.zeros(domain), ["x", "y", "z"], "n/a", extent=domain, gt4py_backend="debug"
)
stencil = FrozenStencil(
two_dim_temporaries_stencil,
origin=(0, 0, 0),
Expand All @@ -164,7 +168,9 @@ def test_stencil_2D_temporaries() -> None:
)
def test_validation_call_count(iterations: tuple[int]):
domain = (2, 2, 5)
quantity = Quantity(np.zeros(domain), ["x", "y", "z"], "n/a", extent=domain)
quantity = Quantity(
np.zeros(domain), ["x", "y", "z"], "n/a", extent=domain, gt4py_backend="debug"
)
stencil_config = StencilConfig(
compilation_config=CompilationConfig(backend="numpy", rebuild=True)
)
Expand Down
2 changes: 2 additions & 0 deletions tests/mpi/test_mpi_halo_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,7 @@ def depth_quantity(
units=units,
origin=origin,
extent=extent,
gt4py_backend="debug",
)
return quantity

Expand Down Expand Up @@ -325,6 +326,7 @@ def zeros_quantity(dims, units, origin, extent, shape, numpy, dtype):
units=units,
origin=origin,
extent=extent,
gt4py_backend="debug",
)
quantity.view[:] = 0.0
return quantity
Expand Down
3 changes: 3 additions & 0 deletions tests/quantity/test_boundary.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def test_boundary_data_1_by_1_array_1_halo():
units="m",
origin=(1, 1),
extent=(1, 1),
gt4py_backend="debug",
)
for side in (
WEST,
Expand Down Expand Up @@ -71,6 +72,7 @@ def test_boundary_data_3d_array_1_halo_z_offset_origin(numpy):
units="m",
origin=(1, 1, 1),
extent=(1, 1, 1),
gt4py_backend="debug",
)
for side in (
WEST,
Expand Down Expand Up @@ -109,6 +111,7 @@ def test_boundary_data_2_by_2_array_2_halo():
units="m",
origin=(2, 2),
extent=(2, 2),
gt4py_backend="debug",
)
for side in (
WEST,
Expand Down
3 changes: 3 additions & 0 deletions tests/quantity/test_deepcopy.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ def test_deepcopy_copy_is_editable_by_view():
extent=(nx, ny, nz),
dims=["x", "y", "z"],
units="",
gt4py_backend="debug",
)
quantity_copy = copy.deepcopy(quantity)
# assertion below is only valid if we're overwriting the entire data through view
Expand All @@ -31,6 +32,7 @@ def test_deepcopy_copy_is_editable_by_data():
extent=(nx, ny, nz),
dims=["x", "y", "z"],
units="",
gt4py_backend="debug",
)
quantity_copy = copy.deepcopy(quantity)
quantity_copy.data[:] = 1.0
Expand All @@ -46,6 +48,7 @@ def test_deepcopy_of_dataclass_is_editable_by_data():
extent=(nx, ny, nz),
dims=["x", "y", "z"],
units="",
gt4py_backend="debug",
)
quantity_copy = copy.deepcopy(quantity)
quantity_copy.data[:] = 1.0
Expand Down
57 changes: 43 additions & 14 deletions tests/quantity/test_quantity.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,14 @@ def data(n_halo, extent_1d, n_dims, numpy, dtype):

@pytest.fixture
def quantity(data, origin, extent, dims, units):
return Quantity(data, origin=origin, extent=extent, dims=dims, units=units)
return Quantity(
data,
origin=origin,
extent=extent,
dims=dims,
units=units,
gt4py_backend="debug",
)


def test_smaller_data_raises(data, origin, extent, dims, units):
Expand All @@ -73,23 +80,49 @@ def test_smaller_data_raises(data, origin, extent, dims, units):
else:
with pytest.raises(ValueError):
Quantity(
small_data, origin=origin, extent=extent, dims=dims, units=units
small_data,
origin=origin,
extent=extent,
dims=dims,
units=units,
gt4py_backend="debug",
)


def test_smaller_dims_raises(data, origin, extent, dims, units):
with pytest.raises(ValueError):
Quantity(data, origin=origin, extent=extent, dims=dims[:-1], units=units)
Quantity(
data,
origin=origin,
extent=extent,
dims=dims[:-1],
units=units,
gt4py_backend="debug",
)


def test_smaller_origin_raises(data, origin, extent, dims, units):
with pytest.raises(ValueError):
Quantity(data, origin=origin[:-1], extent=extent, dims=dims, units=units)
Quantity(
data,
origin=origin[:-1],
extent=extent,
dims=dims,
units=units,
gt4py_backend="debug",
)


def test_smaller_extent_raises(data, origin, extent, dims, units):
with pytest.raises(ValueError):
Quantity(data, origin=origin, extent=extent[:-1], dims=dims, units=units)
Quantity(
data,
origin=origin,
extent=extent[:-1],
dims=dims,
units=units,
gt4py_backend="debug",
)


def test_data_change_affects_quantity(data, quantity, numpy):
Expand Down Expand Up @@ -228,27 +261,23 @@ def test_shift_slice(slice_in, shift, extent, slice_out):
@pytest.mark.parametrize(
"quantity",
[
Quantity(np.array(5), dims=[], units="", gt4py_backend="debug"),
Quantity(
np.array(5),
dims=[],
units="",
),
Quantity(
np.array([1, 2, 3]),
dims=["dimension"],
units="degK",
np.array([1, 2, 3]), dims=["dimension"], units="degK", gt4py_backend="debug"
),
Quantity(
np.random.randn(3, 2, 4),
dims=["dim1", "dim_2", "dimension_3"],
units="m",
gt4py_backend="debug",
),
Quantity(
np.random.randn(8, 6, 6),
dims=["dim1", "dim_2", "dimension_3"],
units="km",
origin=(2, 2, 2),
extent=(4, 2, 2),
gt4py_backend="debug",
),
],
)
Expand All @@ -264,7 +293,7 @@ def test_to_data_array(quantity):


def test_data_setter():
quantity = Quantity(np.ones((5,)), dims=["dim1"], units="")
quantity = Quantity(np.ones((5,)), dims=["dim1"], units="", gt4py_backend="debug")

# After allocation - field and data are the same (origin is 0)
assert quantity.data.shape == quantity.field.shape
Expand Down
9 changes: 8 additions & 1 deletion tests/quantity/test_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,14 @@ def data(n_halo, extent_1d, n_dims, numpy, dtype):

@pytest.fixture
def quantity(data, origin, extent, dims, units):
return Quantity(data, origin=origin, extent=extent, dims=dims, units=units)
return Quantity(
data,
origin=origin,
extent=extent,
dims=dims,
units=units,
gt4py_backend="debug",
)


def test_numpy(quantity, backend):
Expand Down
8 changes: 7 additions & 1 deletion tests/quantity/test_transpose.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ def quantity(quantity_data_input, initial_dims, initial_origin, initial_extent):
units="unit_string",
origin=initial_origin,
extent=initial_extent,
gt4py_backend="debug",
)


Expand Down Expand Up @@ -213,7 +214,12 @@ def test_transpose_invalid_cases(


def test_transpose_retains_attrs(numpy):
quantity = Quantity(numpy.random.randn(3, 4), dims=["x", "y"], units="unit_string")
quantity = Quantity(
numpy.random.randn(3, 4),
dims=["x", "y"],
units="unit_string",
gt4py_backend="debug",
)
quantity._attrs = {"long_name": "500 mb height"}
transposed = quantity.transpose(["y", "x"])
assert transposed.attrs == quantity.attrs
Loading