diff --git a/improver/cli/recursive_filter.py b/improver/cli/recursive_filter.py index 5e14d63fd3..a6cb4726a1 100755 --- a/improver/cli/recursive_filter.py +++ b/improver/cli/recursive_filter.py @@ -16,6 +16,7 @@ def process( *, iterations: int = 1, variable_mask: bool = False, + mask_zeros: bool = False, ): """Module to apply a recursive filter to neighbourhooded data. @@ -50,6 +51,11 @@ def process( different mask. If False and cube is masked, a check will be made that the same mask is present on each spatial slice. If True, each spatial slice of cube may contain a different spatial mask. + mask_zeros (bool): + If set true all of the values of 0 in the cube will be masked, + stopping the recursive filter from spreading values into these areas. + They will then be unmasked later on. If the input cube was masked + this mask will be reapplied to the output at the end. Returns: iris.cube.Cube: @@ -59,5 +65,8 @@ def process( plugin = RecursiveFilter(iterations=iterations) return plugin( - cube, smoothing_coefficients=smoothing_coefficients, variable_mask=variable_mask + cube, + smoothing_coefficients=smoothing_coefficients, + variable_mask=variable_mask, + mask_zeros=mask_zeros, ) diff --git a/improver/nbhood/recursive_filter.py b/improver/nbhood/recursive_filter.py index 01ae05f005..cb876ea512 100644 --- a/improver/nbhood/recursive_filter.py +++ b/improver/nbhood/recursive_filter.py @@ -382,10 +382,6 @@ def process( else: cube_mask = None - if mask_zeros: - cube.data = np.ma.masked_where(cube.data == 0.0, cube.data, copy=False) - # This masks any array element that is zero - cube_format = next(cube.slices([cube.coord(axis="y"), cube.coord(axis="x")])) coeffs_x, coeffs_y = self._validate_coefficients( cube_format, smoothing_coefficients @@ -402,6 +398,11 @@ def process( "Input cube contains spatial slices with different masks." ) + # This masks any array element that is zero. Performed after variable array + # check as zeros may not be located consistently across slices. + if mask_zeros: + cube.data = np.ma.masked_where(cube.data == 0.0, cube.data, copy=False) + recursed_cube = iris.cube.CubeList() for cslice in cube.slices([cube.coord(axis="y"), cube.coord(axis="x")]): padded_cube = pad_cube_with_halo( diff --git a/improver_tests/acceptance/SHA256SUMS b/improver_tests/acceptance/SHA256SUMS index f59d51ecb7..13314fb6ae 100644 --- a/improver_tests/acceptance/SHA256SUMS +++ b/improver_tests/acceptance/SHA256SUMS @@ -839,9 +839,11 @@ f69103cececd76e27bbff5a96e9c74c0e708dcb7f18459ade3eb448639992b34 ./precipitatio ae048c636992e80b79c6cbb44b36339b30ea8d0ef1db72cd3f4de8766346fa1d ./recursive-filter/input.nc b6cdb8bf877bb0b3b78ad224b50b9272b65732bf9e39a88df704209e228bf4c0 ./recursive-filter/input_masked.nc 11c428f6fb0202ab0f975e58e52d17342c50f607aee4fd0e387a2a62c188790e ./recursive-filter/input_variable_masked.nc +ce0ff757524da235a94421a3e2a254958811c562c739a7a73c3fa9cd4d00ddb5 ./recursive-filter/input_with_zeros.nc b4f7acb4fb95640f50a11cdb02038a1e66f2bc02d65b87230ff5a1dab649f141 ./recursive-filter/kgo_basic.nc 0f5ae62721603eb258cd54b195673be0fde0effbecbae7bf88100913a603bd38 ./recursive-filter/kgo_internal_mask_with_re_mask.nc 6e77a397bea8fd914a182ed9152b9b75b9e751968c432d1d45ffe971c6f8a815 ./recursive-filter/kgo_variable_internal_mask_with_re_mask.nc +a0660ea5d5540b9476e8b77d5f488db2f36842c0da3a1d068d52986353fe445f ./recursive-filter/kgo_with_zeros.nc 49497750007d283c609d8d1e1415f9a8be1f1c8b9f4c5c74e3b741ab3c3f681e ./recursive-filter/smoothing_coefficients.nc f4128d6ea8da41f46eb3b86f49cd6a60a0c81dabced89f1016372c20230fd85d ./regrid/basic/kgo.nc b431242e8abec923d1ad6d54022e3602511274e8b35bab81fa63a2383a9020b4 ./regrid/bilinear_2/kgo_multi_realization.nc diff --git a/improver_tests/acceptance/test_recursive_filter.py b/improver_tests/acceptance/test_recursive_filter.py index f62514f023..ed7299b7be 100644 --- a/improver_tests/acceptance/test_recursive_filter.py +++ b/improver_tests/acceptance/test_recursive_filter.py @@ -55,3 +55,21 @@ def test_variable_internal_mask(tmp_path): ] run_cli(args) acc.compare(output_path, kgo_path) + + +def test_zero_masking(tmp_path): + """Test recursive filter with zero masking preserves areas of zeros.""" + kgo_dir = acc.kgo_root() / "recursive-filter" + kgo_path = kgo_dir / "kgo_with_zeros.nc" + input_path = kgo_dir / "input_with_zeros.nc" + smoothing_coefficients_path = kgo_dir / "smoothing_coefficients.nc" + output_path = tmp_path / "output.nc" + args = [ + input_path, + smoothing_coefficients_path, + "--mask-zeros", + "--output", + output_path, + ] + run_cli(args) + acc.compare(output_path, kgo_path)