Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion improver/cli/recursive_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def process(
*,
iterations: int = 1,
variable_mask: bool = False,
mask_zeros: bool = False,
):
"""Module to apply a recursive filter to neighbourhooded data.

Expand Down Expand Up @@ -50,6 +51,11 @@ def process(
different mask. If False and cube is masked, a check will be made that
the same mask is present on each spatial slice. If True, each spatial
slice of cube may contain a different spatial mask.
mask_zeros (bool):
If set true all of the values of 0 in the cube will be masked,
stopping the recursive filter from spreading values into these areas.
They will then be unmasked later on. If the input cube was masked
this mask will be reapplied to the output at the end.

Returns:
iris.cube.Cube:
Expand All @@ -59,5 +65,8 @@ def process(

plugin = RecursiveFilter(iterations=iterations)
return plugin(
cube, smoothing_coefficients=smoothing_coefficients, variable_mask=variable_mask
cube,
smoothing_coefficients=smoothing_coefficients,
variable_mask=variable_mask,
mask_zeros=mask_zeros,
)
9 changes: 5 additions & 4 deletions improver/nbhood/recursive_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,10 +382,6 @@ def process(
else:
cube_mask = None

if mask_zeros:
cube.data = np.ma.masked_where(cube.data == 0.0, cube.data, copy=False)
# This masks any array element that is zero

cube_format = next(cube.slices([cube.coord(axis="y"), cube.coord(axis="x")]))
coeffs_x, coeffs_y = self._validate_coefficients(
cube_format, smoothing_coefficients
Expand All @@ -402,6 +398,11 @@ def process(
"Input cube contains spatial slices with different masks."
)

# This masks any array element that is zero. Performed after variable array
# check as zeros may not be located consistently across slices.
if mask_zeros:
cube.data = np.ma.masked_where(cube.data == 0.0, cube.data, copy=False)

recursed_cube = iris.cube.CubeList()
for cslice in cube.slices([cube.coord(axis="y"), cube.coord(axis="x")]):
padded_cube = pad_cube_with_halo(
Expand Down
2 changes: 2 additions & 0 deletions improver_tests/acceptance/SHA256SUMS
Original file line number Diff line number Diff line change
Expand Up @@ -839,9 +839,11 @@ f69103cececd76e27bbff5a96e9c74c0e708dcb7f18459ade3eb448639992b34 ./precipitatio
ae048c636992e80b79c6cbb44b36339b30ea8d0ef1db72cd3f4de8766346fa1d ./recursive-filter/input.nc
b6cdb8bf877bb0b3b78ad224b50b9272b65732bf9e39a88df704209e228bf4c0 ./recursive-filter/input_masked.nc
11c428f6fb0202ab0f975e58e52d17342c50f607aee4fd0e387a2a62c188790e ./recursive-filter/input_variable_masked.nc
ce0ff757524da235a94421a3e2a254958811c562c739a7a73c3fa9cd4d00ddb5 ./recursive-filter/input_with_zeros.nc
b4f7acb4fb95640f50a11cdb02038a1e66f2bc02d65b87230ff5a1dab649f141 ./recursive-filter/kgo_basic.nc
0f5ae62721603eb258cd54b195673be0fde0effbecbae7bf88100913a603bd38 ./recursive-filter/kgo_internal_mask_with_re_mask.nc
6e77a397bea8fd914a182ed9152b9b75b9e751968c432d1d45ffe971c6f8a815 ./recursive-filter/kgo_variable_internal_mask_with_re_mask.nc
a0660ea5d5540b9476e8b77d5f488db2f36842c0da3a1d068d52986353fe445f ./recursive-filter/kgo_with_zeros.nc
49497750007d283c609d8d1e1415f9a8be1f1c8b9f4c5c74e3b741ab3c3f681e ./recursive-filter/smoothing_coefficients.nc
f4128d6ea8da41f46eb3b86f49cd6a60a0c81dabced89f1016372c20230fd85d ./regrid/basic/kgo.nc
b431242e8abec923d1ad6d54022e3602511274e8b35bab81fa63a2383a9020b4 ./regrid/bilinear_2/kgo_multi_realization.nc
Expand Down
18 changes: 18 additions & 0 deletions improver_tests/acceptance/test_recursive_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,21 @@ def test_variable_internal_mask(tmp_path):
]
run_cli(args)
acc.compare(output_path, kgo_path)


def test_zero_masking(tmp_path):
"""Test recursive filter with zero masking preserves areas of zeros."""
kgo_dir = acc.kgo_root() / "recursive-filter"
kgo_path = kgo_dir / "kgo_with_zeros.nc"
input_path = kgo_dir / "input_with_zeros.nc"
smoothing_coefficients_path = kgo_dir / "smoothing_coefficients.nc"
output_path = tmp_path / "output.nc"
args = [
input_path,
smoothing_coefficients_path,
"--mask-zeros",
"--output",
output_path,
]
run_cli(args)
acc.compare(output_path, kgo_path)