Skip to content
4 changes: 2 additions & 2 deletions ndsl/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

def to_xarray_dataset(state) -> xr.Dataset:
data_vars = {
name: value.data_array for name, value in state.items() if name != "time"
name: value.data_as_xarray for name, value in state.items() if name != "time"
}
if "time" in state:
data_vars["time"] = state["time"]
Expand All @@ -47,7 +47,7 @@ def _extract_time(value: xr.DataArray) -> cftime.datetime:
"""Extract time value from read-in state."""
if value.ndim > 0:
raise ValueError(
"State must be representative of a single scalar time. " f"Got {value}."
f"State must be representative of a single scalar time. Got {value}."
)
time = value.item()
if not isinstance(time, cftime.datetime):
Expand Down
31 changes: 20 additions & 11 deletions ndsl/quantity/quantity.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,9 +153,16 @@ def from_data_array(
gt4py_backend=gt4py_backend,
)

def to_netcdf(self, path: str, name="var", rank: int = -1) -> None:
def to_netcdf(self, path: str, name="var", rank: int = -1, all_data=False) -> None:
if rank < 0 or MPI.COMM_WORLD.Get_rank() == rank:
self.data_array.to_dataset(name=name).to_netcdf(f"{path}__r{rank}.nc4")
if all_data:
self.data_as_xarray.to_dataset(name=name).to_netcdf(
f"{path}__r{rank}.nc4"
)
else:
self.field_as_xarray.to_dataset(name=name).to_netcdf(
f"{path}__r{rank}.nc4"
)

def halo_spec(self, n_halo: int) -> QuantityHaloSpec:
return QuantityHaloSpec(
Expand Down Expand Up @@ -239,6 +246,10 @@ def view(self) -> BoundedArrayView:
"""a view into the computational domain of the underlying data"""
return self._compute_domain_view

@property
def field(self) -> np.ndarray | cupy.ndarray:
return self._compute_domain_view[:]

@property
def data(self) -> Union[np.ndarray, cupy.ndarray]:
"""the underlying array of data"""
Expand All @@ -260,16 +271,14 @@ def extent(self) -> Tuple[int, ...]:
return self.metadata.extent

@property
def data_array(self, full_data=False) -> xr.DataArray:
"""Returns an Xarray.DataArray of the view (domain)
def field_as_xarray(self) -> xr.DataArray:
"""Returns an Xarray.DataArray of the field (domain)"""
return xr.DataArray(self.field, dims=self.dims, attrs=self.attrs)

Args:
full_data: Return the entire data (halo included) instead of the view
"""
if full_data:
return xr.DataArray(self.data[:], dims=self.dims, attrs=self.attrs)
else:
return xr.DataArray(self.view[:], dims=self.dims, attrs=self.attrs)
@property
def data_as_xarray(self) -> xr.DataArray:
"""Returns an Xarray.DataArray of the underlying array"""
return xr.DataArray(self.data, dims=self.dims, attrs=self.attrs)

@property
def np(self) -> NumpyModule:
Expand Down
10 changes: 5 additions & 5 deletions tests/quantity/test_quantity.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,11 +262,11 @@ def test_shift_slice(slice_in, shift, extent, slice_out):
)
@requires_xarray
def test_to_data_array(quantity):
assert quantity.data_array.attrs == quantity.attrs
assert quantity.data_array.dims == quantity.dims
assert quantity.data_array.shape == quantity.extent
np.testing.assert_array_equal(quantity.data_array.values, quantity.view[:])
assert quantity.field_as_xarray.attrs == quantity.attrs
assert quantity.field_as_xarray.dims == quantity.dims
assert quantity.field_as_xarray.shape == quantity.extent
np.testing.assert_array_equal(quantity.field_as_xarray.values, quantity.view[:])
if quantity.extent == quantity.data.shape:
assert (
quantity.data_array.data.ctypes.data == quantity.data.ctypes.data
quantity.field_as_xarray.data.ctypes.data == quantity.data.ctypes.data
), "data memory address is not equal"
39 changes: 38 additions & 1 deletion tests/quantity/test_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def test_numpy(quantity, backend):


@pytest.mark.skipif(gt4py is None, reason="requires gt4py")
def test_modifying_numpy_data_modifies_view():
def test_modifying_numpy_data_modifies_view_and_field():
shape = (6, 6)
data = np.zeros(shape, dtype=float)
quantity = Quantity(
Expand All @@ -91,11 +91,39 @@ def test_modifying_numpy_data_modifies_view():
assert quantity.view[0, 0] == 1
assert quantity.view[2, 2] == 5
assert quantity.view[4, 4] == 3
assert quantity.field[0, 0] == 1
assert quantity.field[2, 2] == 5
assert quantity.field[4, 4] == 3
assert quantity.data[0, 0] == 1
assert quantity.data[2, 2] == 5
assert quantity.data[4, 4] == 3
Comment on lines 97 to 99

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe someday it will make sense to do this with a halo and we can test that the data is different than the field or view

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Your wishes are my command*.

*Subject to budget constraint, time and implementation technical details. Offer can be removed at any time for no reasons. Definition of "wishes" and "command" subject to change.



@pytest.mark.skipif(gt4py is None, reason="requires gt4py")
def test_data_and_field_access_right_full_array_and_compute_domain():
"""Test halo read/write align with data (full array) and field (compute domain)"""
shape = (6, 6)
data = np.zeros(shape, dtype=float)
quantity = Quantity(
data,
origin=(1, 1),
extent=(5, 5),
dims=["dim1", "dim2"],
units="units",
gt4py_backend="numpy",
)
assert np.all(quantity.data == 0)
# Write compute domain - test data is written with the offset
quantity.field[:] = 11.11
assert np.all(quantity.field == 11.11)
assert np.all(quantity.data[1:-1, 1:-1] == 11.11)
assert np.all(quantity.data[0:1, 0:1] == 0)
# Write halo and test field has been left untouched
quantity.data[0:1, 0:1] = 33
assert np.all(quantity.data[0:1, 0:1] == 33)
assert np.all(quantity.field == 11.11)


@pytest.mark.parametrize("backend", ["numpy", "cupy"], indirect=True)
def test_data_exists(quantity, backend):
if "numpy" in backend:
Expand All @@ -104,6 +132,14 @@ def test_data_exists(quantity, backend):
assert isinstance(quantity.data, cp.ndarray)


@pytest.mark.parametrize("backend", ["numpy", "cupy"], indirect=True)
def test_field_exists(quantity, backend):
if "numpy" in backend:
assert isinstance(quantity.field, np.ndarray)
else:
assert isinstance(quantity.field, cp.ndarray)


@pytest.mark.parametrize("backend", ["numpy", "cupy"], indirect=True)
def test_accessing_data_does_not_break_view(
data, origin, extent, dims, units, gt4py_backend
Expand All @@ -118,6 +154,7 @@ def test_accessing_data_does_not_break_view(
)
quantity.data[origin] = -1.0
assert quantity.data[origin] == quantity.view[tuple(0 for _ in origin)]
assert quantity.data[origin] == quantity.field[tuple(0 for _ in origin)]


# run using cupy backend even though unused, to mark this as a "gpu" test
Expand Down