diff --git a/esmvalcore/iris_helpers.py b/esmvalcore/iris_helpers.py index 1191587546..eb9a96461e 100644 --- a/esmvalcore/iris_helpers.py +++ b/esmvalcore/iris_helpers.py @@ -146,10 +146,20 @@ def merge_cube_attributes( attributes.setdefault(attr, []) attributes[attr].append(val) - # Step 2: if values are not equal, first convert them to strings (so that + # Step 2: use the first cube in which an attribute occurs to decide if an + # attribute is global or local. + final_attributes = iris.cube.CubeAttrsDict() + for cube in cubes: + for attr, value in cube.attributes.locals.items(): + if attr not in final_attributes: + final_attributes.locals[attr] = value + for attr, value in cube.attributes.globals.items(): + if attr not in final_attributes: + final_attributes.globals[attr] = value + + # Step 3: if values are not equal, first convert them to strings (so that # set() can be used); then extract unique elements from this list, sort it, - # and use the delimiter to join all elements to a single string - final_attributes: Dict[str, NetCDFAttr] = {} + # and use the delimiter to join all elements to a single string. for (attr, vals) in attributes.items(): set_of_str = sorted({str(v) for v in vals}) if len(set_of_str) == 1: @@ -157,7 +167,7 @@ def merge_cube_attributes( else: final_attributes[attr] = delimiter.join(set_of_str) - # Step 3: modify the cubes in-place + # Step 4: modify the cubes in-place for cube in cubes: cube.attributes = final_attributes diff --git a/tests/unit/test_iris_helpers.py b/tests/unit/test_iris_helpers.py index e7b18e4c67..e49b6b803a 100644 --- a/tests/unit/test_iris_helpers.py +++ b/tests/unit/test_iris_helpers.py @@ -2,6 +2,7 @@ import datetime from copy import deepcopy from itertools import permutations +from pprint import pformat from unittest import mock import dask.array as da @@ -141,8 +142,10 @@ def test_date2num_scalar(date, dtype, expected, units): assert num.dtype == dtype -def assert_attribues_equal(attrs_1: dict, attrs_2: dict) -> None: +def assert_attributes_equal(attrs_1: dict, attrs_2: dict) -> None: """Check attributes using :func:`numpy.testing.assert_array_equal`.""" + print(pformat(dict(attrs_1))) + print(pformat(dict(attrs_2))) assert len(attrs_1) == len(attrs_2) for (attr, val) in attrs_1.items(): assert attr in attrs_2 @@ -210,7 +213,7 @@ def test_merge_cube_attributes(cubes): merge_cube_attributes(cubes) assert len(cubes) == 3 for cube in cubes: - assert_attribues_equal(cube.attributes, expected_attributes) + assert_attributes_equal(cube.attributes, expected_attributes) def test_merge_cube_attributes_0_cubes(): @@ -224,7 +227,20 @@ def test_merge_cube_attributes_1_cube(): expected_attributes = deepcopy(cubes[0].attributes) merge_cube_attributes(cubes) assert len(cubes) == 1 - assert_attribues_equal(cubes[0].attributes, expected_attributes) + assert_attributes_equal(cubes[0].attributes, expected_attributes) + + +def test_merge_cube_attributes_global_local(): + cube1 = CUBES[0].copy() + cube2 = CUBES[1].copy() + cube1.attributes.globals['attr1'] = 1 + cube1.attributes.globals['attr2'] = 1 + cube1.attributes.globals['attr3'] = 1 + cube2.attributes.locals['attr1'] = 1 + merge_cube_attributes([cube1, cube2]) + for cube in [cube1, cube2]: + for attr in ['attr1', 'attr2', 'attr3']: + assert attr in cube.attributes.globals @pytest.fixture