Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions ndsl/comm/communicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def _create_all_reduce_quantity(
units=input_metadata.units,
origin=input_metadata.origin,
extent=input_metadata.extent,
gt4py_backend=input_metadata.gt4py_backend,
backend=input_metadata.backend,
allow_mismatch_float_precision=False,
)
return all_reduce_quantity
Expand Down Expand Up @@ -228,7 +228,7 @@ def _get_gather_recv_quantity(
units=send_metadata.units,
origin=tuple([0 for dim in send_metadata.dims]),
extent=global_extent,
gt4py_backend=send_metadata.gt4py_backend,
backend=send_metadata.backend,
allow_mismatch_float_precision=True,
)
return recv_quantity
Expand All @@ -241,7 +241,7 @@ def _get_scatter_recv_quantity(
send_metadata.np.zeros(shape, dtype=send_metadata.dtype), # type: ignore
dims=send_metadata.dims,
units=send_metadata.units,
gt4py_backend=send_metadata.gt4py_backend,
backend=send_metadata.backend,
allow_mismatch_float_precision=True,
)
return recv_quantity
Expand Down Expand Up @@ -841,7 +841,7 @@ def _get_gather_recv_quantity(
units=metadata.units,
origin=(0,) + tuple([0 for dim in metadata.dims]),
extent=global_extent,
gt4py_backend=metadata.gt4py_backend,
backend=metadata.backend,
allow_mismatch_float_precision=True,
)
return recv_quantity
Expand All @@ -861,7 +861,7 @@ def _get_scatter_recv_quantity(
metadata.np.zeros(shape, dtype=metadata.dtype), # type: ignore
dims=metadata.dims[1:],
units=metadata.units,
gt4py_backend=metadata.gt4py_backend,
backend=metadata.backend,
allow_mismatch_float_precision=True,
)
return recv_quantity
2 changes: 1 addition & 1 deletion ndsl/dsl/ndsl_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,6 @@ def make_local(
units=quantity.units,
origin=quantity.origin,
extent=quantity.extent,
gt4py_backend=quantity.gt4py_backend,
backend=quantity.backend,
allow_mismatch_float_precision=allow_mismatch_float_precision,
)
24 changes: 12 additions & 12 deletions ndsl/grid/generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -551,7 +551,7 @@ def lon(self):
origin=self.grid.origin[:2],
extent=self.grid.extent[:2],
units=self.grid.units,
gt4py_backend=self.grid.gt4py_backend,
backend=self.grid.backend,
number_of_halo_points=N_HALO_DEFAULT,
)

Expand All @@ -563,7 +563,7 @@ def lat(self) -> Quantity:
origin=self.grid.origin[:2],
extent=self.grid.extent[:2],
units=self.grid.units,
gt4py_backend=self.grid.gt4py_backend,
backend=self.grid.backend,
number_of_halo_points=N_HALO_DEFAULT,
)

Expand All @@ -575,7 +575,7 @@ def lon_agrid(self) -> Quantity:
origin=self.agrid.origin[:2],
extent=self.agrid.extent[:2],
units=self.agrid.units,
gt4py_backend=self.agrid.gt4py_backend,
backend=self.agrid.backend,
number_of_halo_points=N_HALO_DEFAULT,
)

Expand All @@ -587,7 +587,7 @@ def lat_agrid(self) -> Quantity:
origin=self.agrid.origin[:2],
extent=self.agrid.extent[:2],
units=self.agrid.units,
gt4py_backend=self.agrid.gt4py_backend,
backend=self.agrid.backend,
number_of_halo_points=N_HALO_DEFAULT,
)

Expand Down Expand Up @@ -1551,7 +1551,7 @@ def rarea(self) -> Quantity:
origin=self.area.origin,
extent=self.area.extent,
units="m^-2",
gt4py_backend=self.area.gt4py_backend,
backend=self.area.backend,
number_of_halo_points=N_HALO_DEFAULT,
)

Expand All @@ -1566,7 +1566,7 @@ def rarea_c(self) -> Quantity:
origin=self.area_c.origin,
extent=self.area_c.extent,
units="m^-2",
gt4py_backend=self.area_c.gt4py_backend,
backend=self.area_c.backend,
number_of_halo_points=N_HALO_DEFAULT,
)

Expand All @@ -1582,7 +1582,7 @@ def rdx(self) -> Quantity:
origin=self.dx.origin,
extent=self.dx.extent,
units="m^-1",
gt4py_backend=self.dx.gt4py_backend,
backend=self.dx.backend,
number_of_halo_points=N_HALO_DEFAULT,
)

Expand All @@ -1598,7 +1598,7 @@ def rdy(self) -> Quantity:
origin=self.dy.origin,
extent=self.dy.extent,
units="m^-1",
gt4py_backend=self.dy.gt4py_backend,
backend=self.dy.backend,
number_of_halo_points=N_HALO_DEFAULT,
)

Expand All @@ -1614,7 +1614,7 @@ def rdxa(self) -> Quantity:
origin=self.dxa.origin,
extent=self.dxa.extent,
units="m^-1",
gt4py_backend=self.dxa.gt4py_backend,
backend=self.dxa.backend,
number_of_halo_points=N_HALO_DEFAULT,
)

Expand All @@ -1630,7 +1630,7 @@ def rdya(self) -> Quantity:
origin=self.dya.origin,
extent=self.dya.extent,
units="m^-1",
gt4py_backend=self.dya.gt4py_backend,
backend=self.dya.backend,
number_of_halo_points=N_HALO_DEFAULT,
)

Expand All @@ -1646,7 +1646,7 @@ def rdxc(self) -> Quantity:
origin=self.dxc.origin,
extent=self.dxc.extent,
units="m^-1",
gt4py_backend=self.dxc.gt4py_backend,
backend=self.dxc.backend,
number_of_halo_points=N_HALO_DEFAULT,
)

Expand All @@ -1662,7 +1662,7 @@ def rdyc(self) -> Quantity:
origin=self.dyc.origin,
extent=self.dyc.extent,
units="m^-1",
gt4py_backend=self.dyc.gt4py_backend,
backend=self.dyc.backend,
number_of_halo_points=N_HALO_DEFAULT,
)

Expand Down
12 changes: 6 additions & 6 deletions ndsl/grid/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def p_interface(self) -> Quantity:
p_interface_data,
dims=[Z_INTERFACE_DIM],
units="Pa",
gt4py_backend=self.ak.gt4py_backend,
backend=self.ak.backend,
number_of_halo_points=self.ak.metadata.n_halo,
)
return self._p_interface
Expand All @@ -203,7 +203,7 @@ def p(self) -> Quantity:
p_data,
dims=[Z_DIM],
units="Pa",
gt4py_backend=self.p_interface.gt4py_backend,
backend=self.p_interface.backend,
number_of_halo_points=self.p_interface.metadata.n_halo,
)
return self._p
Expand All @@ -220,7 +220,7 @@ def dp(self) -> Quantity:
dp_ref_data,
dims=[Z_DIM],
units="Pa",
gt4py_backend=self.ak.gt4py_backend,
backend=self.ak.backend,
number_of_halo_points=self.ak.metadata.n_halo,
)
return self._dp_ref
Expand All @@ -230,7 +230,7 @@ def ptop(self) -> Float:
"""Top of atmosphere pressure (Pa)"""
if self.bk.view[0] != 0:
raise ValueError("ptop is not well-defined when top-of-atmosphere bk != 0")
if self.ak.gt4py_backend is not None and is_gpu_backend(self.ak.gt4py_backend):
if self.ak.backend is not None and is_gpu_backend(self.ak.backend):
return Float(self.ak.view[0].get())
else:
return Float(self.ak.view[0])
Expand Down Expand Up @@ -382,7 +382,7 @@ def _fC_from_data(data, lat: Quantity) -> Quantity:
dims=lat.dims,
origin=lat.origin,
extent=lat.extent,
gt4py_backend=lat.gt4py_backend,
backend=lat.backend,
number_of_halo_points=lat.metadata.n_halo,
)

Expand Down Expand Up @@ -824,7 +824,7 @@ def split_quantity_along_last_dim(quantity: Quantity) -> list[Quantity]:
units=quantity.units,
origin=quantity.origin[:-1],
extent=quantity.extent[:-1],
gt4py_backend=quantity.gt4py_backend,
backend=quantity.backend,
number_of_halo_points=quantity.metadata.n_halo,
)
)
Expand Down
2 changes: 1 addition & 1 deletion ndsl/initialization/allocator.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ def _allocate(
units=units,
origin=origin,
extent=extent,
gt4py_backend=self.backend,
backend=self.backend,
allow_mismatch_float_precision=allow_mismatch_float_precision,
number_of_halo_points=self.sizer.n_halo,
)
Expand Down
24 changes: 18 additions & 6 deletions ndsl/quantity/local.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from typing import Any, Sequence
import warnings
from collections.abc import Sequence
from typing import Any

import dace
import numpy as np
Expand All @@ -20,21 +22,31 @@ def __init__(
data: np.ndarray | cupy.ndarray,
dims: Sequence[str],
units: str,
*,
origin: Sequence[int] | None = None,
extent: Sequence[int] | None = None,
gt4py_backend: str | None = None,
allow_mismatch_float_precision: bool = False,
backend: str | None = None,
):
if gt4py_backend is not None:
warnings.warn(
"gt4py_backend is deprecated. Use `backend` instead.",
DeprecationWarning,
stacklevel=2,
)
if backend is None:
backend = gt4py_backend

super().__init__(
data,
dims,
units,
origin,
extent,
gt4py_backend,
allow_mismatch_float_precision,
origin=origin,
extent=extent,
allow_mismatch_float_precision=allow_mismatch_float_precision,
backend=backend,
)
self._transient = True

def __descriptor__(self) -> Any:
"""Locals uses `Quantity.__descriptor__` and flag itself as transient."""
Expand Down
5 changes: 4 additions & 1 deletion ndsl/quantity/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@ class QuantityMetadata:
dtype: type
"dtype of the data in the ndarray-like object"
gt4py_backend: str | None = None
"backend to use for gt4py storages"
"Deprecated. Use backend instead."
backend: str | None = None
"GT4Py backend name. Used for performance optimal data allocation."

@property
def dim_lengths(self) -> dict[str, int]:
Expand All @@ -57,6 +59,7 @@ def duplicate_metadata(self, metadata_copy: QuantityMetadata) -> None:
metadata_copy.data_type = self.data_type
metadata_copy.dtype = self.dtype
metadata_copy.gt4py_backend = self.gt4py_backend
metadata_copy.backend = self.backend


@dataclasses.dataclass
Expand Down
Loading