Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions examples/NDSL/03_orchestration_basics.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
" orchestrate,\n",
" QuantityFactory,\n",
")\n",
"from ndsl.constants import X_DIM, Y_DIM, Z_DIM\n",
"from ndsl.constants import I_DIM, J_DIM, K_DIM\n",
"from ndsl.dsl.typing import FloatField, Float\n",
"from ndsl.boilerplate import get_factories_single_tile_orchestrated"
]
Expand Down Expand Up @@ -93,7 +93,7 @@
" domain=grid_indexing.domain_compute(),\n",
" )\n",
" self._tmp_field = quantity_factory.zeros(\n",
" [X_DIM, Y_DIM, Z_DIM], \"n/a\", dtype=dtype\n",
" [I_DIM, J_DIM, K_DIM], \"n/a\", dtype=dtype\n",
" )\n",
" self._n_halo = quantity_factory.sizer.n_halo\n",
"\n",
Expand Down Expand Up @@ -134,9 +134,9 @@
" )\n",
" local_sum = LocalSum(stencil_factory, qty_factory)\n",
"\n",
" in_field = qty_factory.zeros([X_DIM, Y_DIM, Z_DIM], \"n/a\", dtype=dtype)\n",
" in_field = qty_factory.zeros([I_DIM, J_DIM, K_DIM], \"n/a\", dtype=dtype)\n",
" in_field.view[:] = 2.0\n",
" out_field = qty_factory.zeros([X_DIM, Y_DIM, Z_DIM], \"n/a\", dtype=dtype)\n",
" out_field = qty_factory.zeros([I_DIM, J_DIM, K_DIM], \"n/a\", dtype=dtype)\n",
"\n",
" # Run\n",
" local_sum(in_field, out_field)"
Expand Down
14 changes: 7 additions & 7 deletions ndsl/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,13 @@ def _get_constant_version(
# Common constants
#####################

I_DIM = X_DIM = "i"
I_INTERFACE_DIM = X_INTERFACE_DIM = "i_interface"
J_DIM = Y_DIM = "j"
J_INTERFACE_DIM = Y_INTERFACE_DIM = "j_interface"
K_DIM = Z_DIM = "k"
K_INTERFACE_DIM = Z_INTERFACE_DIM = "k_interface"
K_SOIL_DIM = Z_SOIL_DIM = "k_soil"
I_DIM = "i"
I_INTERFACE_DIM = "i_interface"
J_DIM = "j"
J_INTERFACE_DIM = "j_interface"
K_DIM = "k"
K_INTERFACE_DIM = "k_interface"
K_SOIL_DIM = "k_soil"

I_DIMS = (I_DIM, I_INTERFACE_DIM)
J_DIMS = (J_DIM, J_INTERFACE_DIM)
Expand Down
20 changes: 0 additions & 20 deletions ndsl/dsl/gt4py_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import warnings
from collections.abc import Callable, Sequence
from functools import wraps
from typing import Any
Expand Down Expand Up @@ -447,25 +446,6 @@ def asarray(array, to_type=np.ndarray, dtype=None, order=None):
return cp.asarray(array, dtype, order)


def is_gpu_backend(backend: Backend) -> bool:
warnings.warn(
"Function `gt4py_utils.is_gpu_backend` is deprecated, please use `Backend.is_gpu_backend()`",
category=DeprecationWarning,
stacklevel=2,
)
return backend.is_gpu_backend()


def backend_is_fortran_aligned(backend: Backend) -> bool:
warnings.warn(
"Function `gt4py_utils.backend_is_fortran_aligned` is deprecated "
"please use `Backend.backend_is_fortran_aligned()`",
category=DeprecationWarning,
stacklevel=2,
)
return backend.is_fortran_aligned()


def zeros(shape, dtype=Float, *, backend: Backend):
storage_type = cp.ndarray if backend.is_gpu_backend() else np.ndarray
xp = cp if cp and storage_type is cp.ndarray else np
Expand Down
27 changes: 4 additions & 23 deletions ndsl/stencils/corners.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import warnings
from typing import Literal, TypeAlias, no_type_check

from gt4py.cartesian import gtscript
Expand All @@ -10,19 +9,10 @@
from ndsl.dsl.typing import FloatField


FillCornersDirection: TypeAlias = Literal["i", "x", "j", "y"]
FillCornersDirection: TypeAlias = Literal["i", "j"]
GridType: TypeAlias = Literal["A", "B"] # Arakawa grid type


def _check_for_deprecation(axis: str) -> None:
if axis in ["x", "y"]:
warnings.warn(
f"Corners direction {axis} is deprecated use 'i' or 'j'",
category=DeprecationWarning,
stacklevel=2,
)


def kslice_from_inputs(
kstart: int, nk: int | None, grid_indexer: GridIndexing
) -> tuple[slice, int]:
Expand Down Expand Up @@ -365,13 +355,12 @@ def __init__(
domain = default_domain
"""The full domain required to do corner computation everywhere"""

_check_for_deprecation(direction)
if direction in ["x", "i"]:
if direction in ["i"]:
defn = fill_corners_bgrid_x_defn
elif direction in ["y", "j"]:
elif direction in ["j"]:
defn = fill_corners_bgrid_y_defn
else:
raise ValueError("Direction must be either 'x' or 'y'")
raise ValueError("Direction must be either 'i' or 'j'")
externals = stencil_factory.grid_indexing.axis_offsets(
origin=origin, domain=domain
)
Expand Down Expand Up @@ -519,7 +508,6 @@ def fill_sw_corner_2d_bgrid(
direction: FillCornersDirection,
grid_indexer: GridIndexing,
) -> None:
_check_for_deprecation(direction)
if direction in ["x", "i"]:
q[grid_indexer.isc - i, grid_indexer.jsc - j, :] = q[
grid_indexer.isc - j, grid_indexer.jsc + i, :
Expand All @@ -537,7 +525,6 @@ def fill_nw_corner_2d_bgrid(
direction: FillCornersDirection,
grid_indexer: GridIndexing,
) -> None:
_check_for_deprecation(direction)
if direction in ["x", "i"]:
q[grid_indexer.isc - i, grid_indexer.jec + 1 + j, :] = q[
grid_indexer.isc - j, grid_indexer.jec + 1 - i, :
Expand All @@ -555,7 +542,6 @@ def fill_se_corner_2d_bgrid(
direction: FillCornersDirection,
grid_indexer: GridIndexing,
) -> None:
_check_for_deprecation(direction)
if direction in ["x", "i"]:
q[grid_indexer.iec + 1 + i, grid_indexer.jsc - j, :] = q[
grid_indexer.iec + 1 + j, grid_indexer.jsc + i, :
Expand All @@ -573,7 +559,6 @@ def fill_ne_corner_2d_bgrid(
direction: FillCornersDirection,
grid_indexer: GridIndexing,
) -> None:
_check_for_deprecation(direction)
if direction in ["x", "i"]:
q[grid_indexer.iec + 1 + i, grid_indexer.jec + 1 + j :] = q[
grid_indexer.iec + 1 + j, grid_indexer.jec + 1 - i, :
Expand All @@ -593,7 +578,6 @@ def fill_sw_corner_2d_agrid(
kstart: int = 0,
nk: int | None = None,
) -> None:
_check_for_deprecation(direction)
kslice, nk = kslice_from_inputs(kstart, nk, grid_indexer)
if direction in ["x", "i"]:
q[grid_indexer.isc - i, grid_indexer.jsc - j, kslice] = q[
Expand All @@ -614,7 +598,6 @@ def fill_nw_corner_2d_agrid(
kstart: int = 0,
nk: int | None = None,
) -> None:
_check_for_deprecation(direction)
kslice, nk = kslice_from_inputs(kstart, nk, grid_indexer)
if direction in ["x", "i"]:
q[grid_indexer.isc - i, grid_indexer.jec + j, kslice] = q[
Expand All @@ -635,7 +618,6 @@ def fill_se_corner_2d_agrid(
kstart: int = 0,
nk: int | None = None,
) -> None:
_check_for_deprecation(direction)
kslice, nk = kslice_from_inputs(kstart, nk, grid_indexer)
if direction in ["x", "i"]:
q[grid_indexer.iec + i, grid_indexer.jsc - j, kslice] = q[
Expand All @@ -656,7 +638,6 @@ def fill_ne_corner_2d_agrid(
kstart: int = 0,
nk: int | None = None,
) -> None:
_check_for_deprecation(direction)
kslice, nk = kslice_from_inputs(kstart, nk, grid_indexer)
if direction in ["x", "i"]:
q[grid_indexer.iec + i, grid_indexer.jec + j, kslice] = q[
Expand Down
3 changes: 1 addition & 2 deletions tests/test_zarr_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
J_INTERFACE_DIM,
K_DIM,
K_SOIL_DIM,
X_DIM,
)
from ndsl.monitor.zarr_monitor import ZarrMonitor, array_chunks, get_calendar
from ndsl.optional_imports import zarr
Expand Down Expand Up @@ -95,7 +94,7 @@ def base_state(request, nz, ny, nx, numpy) -> dict:
return {
"var1": Quantity(
numpy.ones([ny, nx]),
dims=(J_DIM, X_DIM),
dims=(J_DIM, I_DIM),
units="m",
backend=Backend.python(),
)
Expand Down