From 970dad875b1fa374edfdc3b4461ff49a70c336fd Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Fri, 9 May 2025 14:22:17 -0400 Subject: [PATCH 1/6] Introduce `field` as a replacement for `quantity.view[:]` Rename `data_array` as `data_as_xarray` and introduce `field_as_xarray` --- ndsl/quantity/quantity.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/ndsl/quantity/quantity.py b/ndsl/quantity/quantity.py index 4f80fff1..4abe1bfb 100644 --- a/ndsl/quantity/quantity.py +++ b/ndsl/quantity/quantity.py @@ -239,6 +239,10 @@ def view(self) -> BoundedArrayView: """a view into the computational domain of the underlying data""" return self._compute_domain_view + @property + def field(self) -> np.ndarray | cupy.ndarray: + return self._compute_domain_view[:] + @property def data(self) -> Union[np.ndarray, cupy.ndarray]: """the underlying array of data""" @@ -260,16 +264,14 @@ def extent(self) -> Tuple[int, ...]: return self.metadata.extent @property - def data_array(self, full_data=False) -> xr.DataArray: - """Returns an Xarray.DataArray of the view (domain) + def field_as_xarray(self) -> xr.DataArray: + """Returns an Xarray.DataArray of the field (domain)""" + return xr.DataArray(self.field, dims=self.dims, attrs=self.attrs) - Args: - full_data: Return the entire data (halo included) instead of the view - """ - if full_data: - return xr.DataArray(self.data[:], dims=self.dims, attrs=self.attrs) - else: - return xr.DataArray(self.view[:], dims=self.dims, attrs=self.attrs) + @property + def data_as_xarray(self) -> xr.DataArray: + """Returns an Xarray.DataArray of the underlying array""" + return xr.DataArray(self.data, dims=self.dims, attrs=self.attrs) @property def np(self) -> NumpyModule: From c68bc9f2fd8e331d80a08354c730882ef6a852ef Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Fri, 9 May 2025 14:27:11 -0400 Subject: [PATCH 2/6] Unit tests --- tests/quantity/test_storage.py | 14 +++++++++++++- 1 file changed, 13 
insertions(+), 1 deletion(-) diff --git a/tests/quantity/test_storage.py b/tests/quantity/test_storage.py index 2cdb8d49..a3853e5d 100644 --- a/tests/quantity/test_storage.py +++ b/tests/quantity/test_storage.py @@ -73,7 +73,7 @@ def test_numpy(quantity, backend): @pytest.mark.skipif(gt4py is None, reason="requires gt4py") -def test_modifying_numpy_data_modifies_view(): +def test_modifying_numpy_data_modifies_view_and_field(): shape = (6, 6) data = np.zeros(shape, dtype=float) quantity = Quantity( @@ -91,6 +91,9 @@ def test_modifying_numpy_data_modifies_view(): assert quantity.view[0, 0] == 1 assert quantity.view[2, 2] == 5 assert quantity.view[4, 4] == 3 + assert quantity.field[0, 0] == 1 + assert quantity.field[2, 2] == 5 + assert quantity.field[4, 4] == 3 assert quantity.data[0, 0] == 1 assert quantity.data[2, 2] == 5 assert quantity.data[4, 4] == 3 @@ -104,6 +107,14 @@ def test_data_exists(quantity, backend): assert isinstance(quantity.data, cp.ndarray) +@pytest.mark.parametrize("backend", ["numpy", "cupy"], indirect=True) +def test_field_exists(quantity, backend): + if "numpy" in backend: + assert isinstance(quantity.field, np.ndarray) + else: + assert isinstance(quantity.field, cp.ndarray) + + @pytest.mark.parametrize("backend", ["numpy", "cupy"], indirect=True) def test_accessing_data_does_not_break_view( data, origin, extent, dims, units, gt4py_backend @@ -118,6 +129,7 @@ def test_accessing_data_does_not_break_view( ) quantity.data[origin] = -1.0 assert quantity.data[origin] == quantity.view[tuple(0 for _ in origin)] + assert quantity.data[origin] == quantity.field[tuple(0 for _ in origin)] # run using cupy backend even though unused, to mark this as a "gpu" test From 68de0ceb10b5ab8567a6acb1ce9c8ebbc9048dbb Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Fri, 9 May 2025 14:31:19 -0400 Subject: [PATCH 3/6] Fix usage of `data_array` --- ndsl/quantity/quantity.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git 
a/ndsl/quantity/quantity.py b/ndsl/quantity/quantity.py index 4abe1bfb..f958c7a0 100644 --- a/ndsl/quantity/quantity.py +++ b/ndsl/quantity/quantity.py @@ -153,9 +153,16 @@ def from_data_array( gt4py_backend=gt4py_backend, ) - def to_netcdf(self, path: str, name="var", rank: int = -1) -> None: + def to_netcdf(self, path: str, name="var", rank: int = -1, all_data=False) -> None: if rank < 0 or MPI.COMM_WORLD.Get_rank() == rank: - self.data_array.to_dataset(name=name).to_netcdf(f"{path}__r{rank}.nc4") + if all_data: + self.data_as_xarray.to_dataset(name=name).to_netcdf( + f"{path}__r{rank}.nc4" + ) + else: + self.field_as_xarray.to_dataset(name=name).to_netcdf( + f"{path}__r{rank}.nc4" + ) def halo_spec(self, n_halo: int) -> QuantityHaloSpec: return QuantityHaloSpec( From 2f06b653735857047fc826f7c79d14dfdbf0f5ac Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Fri, 9 May 2025 14:44:48 -0400 Subject: [PATCH 4/6] More fix --- ndsl/io.py | 4 ++-- tests/quantity/test_quantity.py | 14 +++++++------- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/ndsl/io.py b/ndsl/io.py index b07248bb..c3024b07 100644 --- a/ndsl/io.py +++ b/ndsl/io.py @@ -22,7 +22,7 @@ def to_xarray_dataset(state) -> xr.Dataset: data_vars = { - name: value.data_array for name, value in state.items() if name != "time" + name: value.data_as_xarray for name, value in state.items() if name != "time" } if "time" in state: data_vars["time"] = state["time"] @@ -47,7 +47,7 @@ def _extract_time(value: xr.DataArray) -> cftime.datetime: """Extract time value from read-in state.""" if value.ndim > 0: raise ValueError( - "State must be representative of a single scalar time. " f"Got {value}." + f"State must be representative of a single scalar time. Got {value}." 
) time = value.item() if not isinstance(time, cftime.datetime): diff --git a/tests/quantity/test_quantity.py b/tests/quantity/test_quantity.py index 61e92025..8ee04bc5 100644 --- a/tests/quantity/test_quantity.py +++ b/tests/quantity/test_quantity.py @@ -262,11 +262,11 @@ def test_shift_slice(slice_in, shift, extent, slice_out): ) @requires_xarray def test_to_data_array(quantity): - assert quantity.data_array.attrs == quantity.attrs - assert quantity.data_array.dims == quantity.dims - assert quantity.data_array.shape == quantity.extent - np.testing.assert_array_equal(quantity.data_array.values, quantity.view[:]) + assert quantity.field_as_xarray.attrs == quantity.attrs + assert quantity.field_as_xarray.dims == quantity.dims + assert quantity.field_as_xarray.shape == quantity.extent + np.testing.assert_array_equal(quantity.field_as_xarray.values, quantity.view[:]) if quantity.extent == quantity.data.shape: - assert ( - quantity.data_array.data.ctypes.data == quantity.data.ctypes.data - ), "data memory address is not equal" + assert quantity.field_as_xarray.data.ctypes.data == quantity.data.ctypes.data, ( + "data memory address is not equal" + ) From f446841c45bb93b97ead8a8bce2bc84094572e66 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Fri, 9 May 2025 14:48:14 -0400 Subject: [PATCH 5/6] Lint --- tests/quantity/test_quantity.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/quantity/test_quantity.py b/tests/quantity/test_quantity.py index 8ee04bc5..44d46bf5 100644 --- a/tests/quantity/test_quantity.py +++ b/tests/quantity/test_quantity.py @@ -267,6 +267,6 @@ def test_to_data_array(quantity): assert quantity.field_as_xarray.shape == quantity.extent np.testing.assert_array_equal(quantity.field_as_xarray.values, quantity.view[:]) if quantity.extent == quantity.data.shape: - assert quantity.field_as_xarray.data.ctypes.data == quantity.data.ctypes.data, ( - "data memory address is not equal" - ) + assert ( + 
quantity.field_as_xarray.data.ctypes.data == quantity.data.ctypes.data + ), "data memory address is not equal" From b42969cad9b2aaf1a0a5a161a136e37cf579cde0 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Tue, 13 May 2025 09:18:03 -0400 Subject: [PATCH 6/6] Add a halo read/write data/field test --- tests/quantity/test_storage.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/tests/quantity/test_storage.py b/tests/quantity/test_storage.py index a3853e5d..8624d92b 100644 --- a/tests/quantity/test_storage.py +++ b/tests/quantity/test_storage.py @@ -99,6 +99,31 @@ def test_modifying_numpy_data_modifies_view_and_field(): assert quantity.data[4, 4] == 3 +@pytest.mark.skipif(gt4py is None, reason="requires gt4py") +def test_data_and_field_access_right_full_array_and_compute_domain(): + """Test halo read/write align with data (full array) and field (compute domain)""" + shape = (6, 6) + data = np.zeros(shape, dtype=float) + quantity = Quantity( + data, + origin=(1, 1), + extent=(5, 5), + dims=["dim1", "dim2"], + units="units", + gt4py_backend="numpy", + ) + assert np.all(quantity.data == 0) + # Write compute domain - test data is written with the offset + quantity.field[:] = 11.11 + assert np.all(quantity.field == 11.11) + assert np.all(quantity.data[1:-1, 1:-1] == 11.11) + assert np.all(quantity.data[0:1, 0:1] == 0) + # Write halo and test field has been left untouched + quantity.data[0:1, 0:1] = 33 + assert np.all(quantity.data[0:1, 0:1] == 33) + assert np.all(quantity.field == 11.11) + + @pytest.mark.parametrize("backend", ["numpy", "cupy"], indirect=True) def test_data_exists(quantity, backend): if "numpy" in backend: