Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 50 additions & 0 deletions ndsl/stencils/column_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,31 @@ def column_max(field, start_index, end_index):
return field.at(K=max_index), max_index


@typing.no_type_check
@function
def column_max_ddim(field, ddim, start_index, end_index):
Comment thread
romanc marked this conversation as resolved.
"""
Find the maximum value for a full or slice of a column.

Args:
field: data to be analyzed
start_index: "bottom" index of slice, must be less than end_index
end_index: "top" index of slice, must be greater than start_index

Returns: [max value, index of max value]
"""
max_index = 0
level = start_index
while level <= end_index:
new = field.at(K=level, ddim=[ddim])
old = field.at(K=max_index, ddim=[ddim])
if new > old:
max_index = level
level += 1

return field.at(K=max_index, ddim=[ddim]), max_index


@typing.no_type_check
@function
def column_min(field, start_index, end_index):
Expand All @@ -51,3 +76,28 @@ def column_min(field, start_index, end_index):
level += 1

return field.at(K=min_index), min_index


@typing.no_type_check
@function
def column_min_ddim(field, ddim, start_index, end_index):
"""
Find the minimum value for a full or slice of a column.

Args:
field: data to be analyzed
start_index: "bottom" index of slice, must be less than end_index
end_index: "top" index of slice, must be greater than start_index

Returns: [min value, index of min value]
"""
min_index = 0
level = start_index
while level <= end_index:
new = field.at(K=level, ddim=[ddim])
old = field.at(K=min_index, ddim=[ddim])
if new < old:
min_index = level
level += 1

return field.at(K=min_index, ddim=[ddim]), min_index
90 changes: 86 additions & 4 deletions tests/stencils/test_stencils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,17 @@
from ndsl.boilerplate import get_factories_single_tile
from ndsl.constants import X_DIM, Y_DIM, Z_DIM
from ndsl.dsl.gt4py import FORWARD, computation, interval
from ndsl.dsl.typing import FloatField, FloatFieldIJ
from ndsl.dsl.typing import Float, FloatField, FloatFieldIJ, set_4d_field_size
from ndsl.stencils import CopyCornersXY
from ndsl.stencils.column_operations import column_max, column_min
from ndsl.stencils.column_operations import (
column_max,
column_max_ddim,
column_min,
column_min_ddim,
)


FloatField_ddim = set_4d_field_size(2, Float)


@pytest.fixture
Expand All @@ -26,6 +34,14 @@ def column_max_stencil(
with computation(FORWARD), interval(0, 1):
max_value, max_index = column_max(data, 0, k_end)

def column_max_ddim_stencil(
data: FloatField_ddim, max_value: FloatFieldIJ, max_index: FloatFieldIJ
):
from __externals__ import k_end

with computation(FORWARD), interval(0, 1):
max_value, max_index = column_max_ddim(data, 1, 0, k_end)

def column_min_stencil(
data: FloatField, min_value: FloatFieldIJ, min_index: FloatFieldIJ
):
Expand All @@ -34,14 +50,30 @@ def column_min_stencil(
with computation(FORWARD), interval(0, 1):
min_value, min_index = column_min(data, 5, k_end)

def column_min_ddim_stencil(
data: FloatField_ddim, min_value: FloatFieldIJ, min_index: FloatFieldIJ
):
from __externals__ import k_end

with computation(FORWARD), interval(0, 1):
min_value, min_index = column_min_ddim(data, 1, 5, k_end)

self._column_max_stencil = stencil_factory.from_dims_halo(
func=column_max_stencil,
compute_dims=[X_DIM, Y_DIM, Z_DIM],
)
self._column_max_ddim_stencil = stencil_factory.from_dims_halo(
func=column_max_ddim_stencil,
compute_dims=[X_DIM, Y_DIM, Z_DIM],
)
self._column_min_stencil = stencil_factory.from_dims_halo(
func=column_min_stencil,
compute_dims=[X_DIM, Y_DIM, Z_DIM],
)
self._column_min_ddim_stencil = stencil_factory.from_dims_halo(
func=column_min_ddim_stencil,
compute_dims=[X_DIM, Y_DIM, Z_DIM],
)

def __call__(
self,
Expand All @@ -50,13 +82,21 @@ def __call__(
max_index: FloatFieldIJ,
min_value: FloatFieldIJ,
min_index: FloatFieldIJ,
data_ddim: FloatField_ddim,
max_value_ddim: FloatFieldIJ,
max_index_ddim: FloatFieldIJ,
min_value_ddim: FloatFieldIJ,
min_index_ddim: FloatFieldIJ,
):
self._column_max_stencil(data, max_value, max_index)
self._column_max_ddim_stencil(data_ddim, max_value_ddim, max_index_ddim)
self._column_min_stencil(data, min_value, min_index)
self._column_min_ddim_stencil(data_ddim, min_value_ddim, min_index_ddim)


def test_column_operations(boilerplate):
stencil_factory, quantity_factory = boilerplate
quantity_factory.add_data_dimensions({"ddim": 2})
data = quantity_factory.zeros([X_DIM, Y_DIM, Z_DIM], "n/a")
data.field[:] = [
47.3821,
Expand All @@ -71,19 +111,61 @@ def test_column_operations(boilerplate):
7.3504,
]

data_ddim = quantity_factory.zeros([X_DIM, Y_DIM, Z_DIM, "ddim"], "n/a")
data_ddim.field[:] = [
[
[
[47.3821, 27.4825],
[2.9157, 93.1242],
[88.6034, 14.6347],
[71.9275, 58.2094],
[53.1412, 6.2369],
[19.4783, 71.5457],
[94.2258, 42.3091],
[36.8099, 89.7718],
[64.0175, 63.1910],
[7.3504, 3.4991],
]
]
]

max_value = quantity_factory.zeros([X_DIM, Y_DIM], "n/a")
max_index = quantity_factory.zeros([X_DIM, Y_DIM], "n/a")
min_value = quantity_factory.zeros([X_DIM, Y_DIM], "n/a")
min_index = quantity_factory.zeros([X_DIM, Y_DIM], "n/a")
max_value_ddim = quantity_factory.zeros([X_DIM, Y_DIM], "n/a")
max_index_ddim = quantity_factory.zeros([X_DIM, Y_DIM], "n/a")
min_value_ddim = quantity_factory.zeros([X_DIM, Y_DIM], "n/a")
min_index_ddim = quantity_factory.zeros([X_DIM, Y_DIM], "n/a")

code = ColumnOperations(stencil_factory)
code(data, max_value, max_index, min_value, min_index)

code(
data,
max_value,
max_index,
min_value,
min_index,
data_ddim,
max_value_ddim,
max_index_ddim,
min_value_ddim,
min_index_ddim,
)

# 3d field tests
assert max_value.field[:] == np.max(data.field[:], axis=2)
assert max_index.field[:] == np.argmax(data.field[:], axis=2)
assert min_value.field[:] == np.min(data.field[:, :, 5:], axis=2)
assert min_index.field[:] == 5 + np.argmin(data.field[:, :, 5:], axis=2)

# 4d field tests
assert max_value_ddim.field[:] == np.max(data_ddim.field[:, :, :, 1], axis=2)
assert max_index_ddim.field[:] == np.argmax(data_ddim.field[:, :, :, 1], axis=2)
assert min_value_ddim.field[:] == np.min(data_ddim.field[:, :, 5:, 1], axis=2)
assert min_index_ddim.field[:] == 5 + np.argmin(
data_ddim.field[:, :, 5:, 1], axis=2
)


def test_CopyCornersXY_deprecation(boilerplate) -> None:
stencil_factory, _ = boilerplate
Expand Down