diff --git a/ndsl/stencils/column_operations.py b/ndsl/stencils/column_operations.py index 21838ea2..25ede5de 100644 --- a/ndsl/stencils/column_operations.py +++ b/ndsl/stencils/column_operations.py @@ -28,6 +28,31 @@ def column_max(field, start_index, end_index): return field.at(K=max_index), max_index +@typing.no_type_check +@function +def column_max_ddim(field, ddim, start_index, end_index): + """ + Find the maximum value for a full or slice of a column. + + Args: + field: data to be analyzed + start_index: "bottom" index of slice, must be less than end_index + end_index: "top" index of slice, must be greater than start_index + + Returns: [max value, index of max value] + """ + max_index = 0 + level = start_index + while level <= end_index: + new = field.at(K=level, ddim=[ddim]) + old = field.at(K=max_index, ddim=[ddim]) + if new > old: + max_index = level + level += 1 + + return field.at(K=max_index, ddim=[ddim]), max_index + + @typing.no_type_check @function def column_min(field, start_index, end_index): @@ -51,3 +76,28 @@ def column_min(field, start_index, end_index): level += 1 return field.at(K=min_index), min_index + + +@typing.no_type_check +@function +def column_min_ddim(field, ddim, start_index, end_index): + """ + Find the minimum value for a full or slice of a column. + + Args: + field: data to be analyzed + start_index: "bottom" index of slice, must be less than end_index + end_index: "top" index of slice, must be greater than start_index + + Returns: [min value, index of min value] + """ + min_index = 0 + level = start_index + while level <= end_index: + new = field.at(K=level, ddim=[ddim]) + old = field.at(K=min_index, ddim=[ddim]) + if new < old: + min_index = level + level += 1 + + return field.at(K=min_index, ddim=[ddim]), min_index diff --git a/tests/stencils/test_stencils.py b/tests/stencils/test_stencils.py index 73802a23..48d2232a 100644 --- a/tests/stencils/test_stencils.py +++ b/tests/stencils/test_stencils.py @@ -5,9 +5,17 @@ from ndsl.boilerplate import get_factories_single_tile from ndsl.constants import X_DIM, Y_DIM, Z_DIM from ndsl.dsl.gt4py import FORWARD, computation, interval -from ndsl.dsl.typing import FloatField, FloatFieldIJ +from ndsl.dsl.typing import Float, FloatField, FloatFieldIJ, set_4d_field_size from ndsl.stencils import CopyCornersXY -from ndsl.stencils.column_operations import column_max, column_min +from ndsl.stencils.column_operations import ( + column_max, + column_max_ddim, + column_min, + column_min_ddim, +) + + +FloatField_ddim = set_4d_field_size(2, Float) @pytest.fixture @@ -26,6 +34,14 @@ def column_max_stencil( with computation(FORWARD), interval(0, 1): max_value, max_index = column_max(data, 0, k_end) + def column_max_ddim_stencil( + data: FloatField_ddim, max_value: FloatFieldIJ, max_index: FloatFieldIJ + ): + from __externals__ import k_end + + with computation(FORWARD), interval(0, 1): + max_value, max_index = column_max_ddim(data, 1, 0, k_end) + def column_min_stencil( data: FloatField, min_value: FloatFieldIJ, min_index: FloatFieldIJ ): @@ -34,14 +50,30 @@ def column_min_stencil( with computation(FORWARD), interval(0, 1): min_value, min_index = column_min(data, 5, k_end) + def column_min_ddim_stencil( + data: FloatField_ddim, min_value: FloatFieldIJ, min_index: FloatFieldIJ + ): + from __externals__ import k_end + + with computation(FORWARD), interval(0, 1): + min_value, min_index = column_min_ddim(data, 1, 5, k_end) + self._column_max_stencil = stencil_factory.from_dims_halo( func=column_max_stencil, compute_dims=[X_DIM, Y_DIM, Z_DIM], ) + self._column_max_ddim_stencil = stencil_factory.from_dims_halo( + func=column_max_ddim_stencil, + compute_dims=[X_DIM, Y_DIM, Z_DIM], + ) self._column_min_stencil = stencil_factory.from_dims_halo( func=column_min_stencil, compute_dims=[X_DIM, Y_DIM, Z_DIM], ) + self._column_min_ddim_stencil = stencil_factory.from_dims_halo( + func=column_min_ddim_stencil, + compute_dims=[X_DIM, Y_DIM, Z_DIM], + ) def __call__( self, @@ -50,13 +82,21 @@ def __call__( max_index: FloatFieldIJ, min_value: FloatFieldIJ, min_index: FloatFieldIJ, + data_ddim: FloatField_ddim, + max_value_ddim: FloatFieldIJ, + max_index_ddim: FloatFieldIJ, + min_value_ddim: FloatFieldIJ, + min_index_ddim: FloatFieldIJ, ): self._column_max_stencil(data, max_value, max_index) + self._column_max_ddim_stencil(data_ddim, max_value_ddim, max_index_ddim) self._column_min_stencil(data, min_value, min_index) + self._column_min_ddim_stencil(data_ddim, min_value_ddim, min_index_ddim) def test_column_operations(boilerplate): stencil_factory, quantity_factory = boilerplate + quantity_factory.add_data_dimensions({"ddim": 2}) data = quantity_factory.zeros([X_DIM, Y_DIM, Z_DIM], "n/a") data.field[:] = [ 47.3821, @@ -71,19 +111,61 @@ def test_column_operations(boilerplate): 7.3504, ] + data_ddim = quantity_factory.zeros([X_DIM, Y_DIM, Z_DIM, "ddim"], "n/a") + data_ddim.field[:] = [ + [ + [ + [47.3821, 27.4825], + [2.9157, 93.1242], + [88.6034, 14.6347], + [71.9275, 58.2094], + [53.1412, 6.2369], + [19.4783, 71.5457], + [94.2258, 42.3091], + [36.8099, 89.7718], + [64.0175, 63.1910], + [7.3504, 3.4991], + ] + ] + ] + max_value = quantity_factory.zeros([X_DIM, Y_DIM], "n/a") max_index = quantity_factory.zeros([X_DIM, Y_DIM], "n/a") min_value = quantity_factory.zeros([X_DIM, Y_DIM], "n/a") min_index = quantity_factory.zeros([X_DIM, Y_DIM], "n/a") + max_value_ddim = quantity_factory.zeros([X_DIM, Y_DIM], "n/a") + max_index_ddim = quantity_factory.zeros([X_DIM, Y_DIM], "n/a") + min_value_ddim = quantity_factory.zeros([X_DIM, Y_DIM], "n/a") + min_index_ddim = quantity_factory.zeros([X_DIM, Y_DIM], "n/a") code = ColumnOperations(stencil_factory) - code(data, max_value, max_index, min_value, min_index) - + code( + data, + max_value, + max_index, + min_value, + min_index, + data_ddim, + max_value_ddim, + max_index_ddim, + min_value_ddim, + min_index_ddim, + ) + + # 3d field tests assert max_value.field[:] == np.max(data.field[:], axis=2) assert max_index.field[:] == np.argmax(data.field[:], axis=2) assert min_value.field[:] == np.min(data.field[:, :, 5:], axis=2) assert min_index.field[:] == 5 + np.argmin(data.field[:, :, 5:], axis=2) + # 4d field tests + assert max_value_ddim.field[:] == np.max(data_ddim.field[:, :, :, 1], axis=2) + assert max_index_ddim.field[:] == np.argmax(data_ddim.field[:, :, :, 1], axis=2) + assert min_value_ddim.field[:] == np.min(data_ddim.field[:, :, 5:, 1], axis=2) + assert min_index_ddim.field[:] == 5 + np.argmin( + data_ddim.field[:, :, 5:, 1], axis=2 + ) + def test_CopyCornersXY_deprecation(boilerplate) -> None: stencil_factory, _ = boilerplate