Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions lib/iris/analysis/maths.py
Original file line number Diff line number Diff line change
Expand Up @@ -944,6 +944,10 @@ def _math_op_common(
new_cube.data = ma.masked_array(0, 1, dtype=new_dtype)

iris.analysis.clear_phenomenon_identity(new_cube)
for cm in cube.cell_measures():
new_cube.remove_cell_measure(cm)
for av in cube.ancillary_variables():
new_cube.remove_ancillary_variable(av)
new_cube.units = new_unit
return new_cube

Expand Down
35 changes: 34 additions & 1 deletion lib/iris/tests/unit/analysis/maths/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from numpy import ma

from iris.analysis import MEAN
from iris.coords import DimCoord
from iris.coords import DimCoord, CellMeasure, AncillaryVariable
from iris.cube import Cube
import iris.tests as tests
import iris.tests.stock as stock
Expand Down Expand Up @@ -212,3 +212,36 @@ def test_masked_constant_not_in_place(self):
self.assertMaskedArrayEqual(ma.masked_array(0, 1), res.data)
self.assertEqual(dtype, res.dtype)
self.assertIsNot(res, cube)


class CubeArithmeticAncillaryHandlingTestMixin(metaclass=ABCMeta):
@property
@abstractmethod
def cube_func(self):
# Define an iris arithmetic function to be called
# I.E. 'iris.analysis.maths.xx'.
pass

def test_cell_measure_removal(self):
cube1 = Cube([0])
cm = CellMeasure([0], long_name="cm")
cube1.add_cell_measure(cm)
cube2 = Cube([0])
res1 = cube1 + cube2
res2 = cube2 + cube1
res3 = cube1 + cube1
self.assertEqual(res1.cell_measures(), [])
self.assertEqual(res2.cell_measures(), [])
self.assertEqual(res3.cell_measures(), [])

def test_ancillary_removal(self):
cube1 = Cube([0])
av = AncillaryVariable([0], long_name="av")
cube1.add_ancillary_variable(av)
cube2 = Cube([0])
res1 = cube1 + cube2
res2 = cube2 + cube1
res3 = cube1 + cube1
self.assertEqual(res1.ancillary_variables(), [])
self.assertEqual(res2.ancillary_variables(), [])
self.assertEqual(res3.ancillary_variables(), [])
10 changes: 10 additions & 0 deletions lib/iris/tests/unit/analysis/maths/test_add.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
CubeArithmeticCoordsTest,
CubeArithmeticMaskedConstantTestMixin,
CubeArithmeticMaskingTestMixin,
CubeArithmeticAncillaryHandlingTestMixin,
)


Expand Down Expand Up @@ -70,5 +71,14 @@ def cube_func(self):
return add


@tests.iristest_timing_decorator
class TestAncillaryHandling(
tests.IrisTest_nometa, CubeArithmeticAncillaryHandlingTestMixin
):
@property
def cube_func(self):
return add


if __name__ == "__main__":
tests.main()
10 changes: 10 additions & 0 deletions lib/iris/tests/unit/analysis/maths/test_divide.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
CubeArithmeticBroadcastingTestMixin,
CubeArithmeticMaskingTestMixin,
CubeArithmeticCoordsTest,
CubeArithmeticAncillaryHandlingTestMixin,
)


Expand Down Expand Up @@ -86,5 +87,14 @@ def test_reversed_points(self):
divide(cube1, cube2)


@tests.iristest_timing_decorator
class TestAncillaryHandling(
tests.IrisTest_nometa, CubeArithmeticAncillaryHandlingTestMixin
):
@property
def cube_func(self):
return divide


if __name__ == "__main__":
tests.main()
10 changes: 10 additions & 0 deletions lib/iris/tests/unit/analysis/maths/test_multiply.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
CubeArithmeticCoordsTest,
CubeArithmeticMaskedConstantTestMixin,
CubeArithmeticMaskingTestMixin,
CubeArithmeticAncillaryHandlingTestMixin,
)


Expand Down Expand Up @@ -70,5 +71,14 @@ def cube_func(self):
return multiply


@tests.iristest_timing_decorator
class TestAncillaryHandling(
tests.IrisTest_nometa, CubeArithmeticAncillaryHandlingTestMixin
):
@property
def cube_func(self):
return multiply


if __name__ == "__main__":
tests.main()