diff --git a/lib/iris/common/metadata.py b/lib/iris/common/metadata.py index 9b1d3278f3..3b4f2b80d3 100644 --- a/lib/iris/common/metadata.py +++ b/lib/iris/common/metadata.py @@ -34,6 +34,7 @@ "AncillaryVariableMetadata", "BaseMetadata", "CellMeasureMetadata", + "ConnectivityMetadata", "CoordMetadata", "CubeMetadata", "DimCoordMetadata", @@ -181,9 +182,10 @@ def func(field): return result # Note that, for strict we use "_fields" not "_members". - # The "circular" member does not participate in strict equivalence. + # The "circular" and "src_dim" members do not participate in strict equivalence. fields = filter( - lambda field: field != "circular", self._fields + lambda field: field not in ("circular", "src_dim"), + self._fields, ) result = all([func(field) for field in fields]) @@ -875,6 +877,126 @@ def equal(self, other, lenient=None): return super().equal(other, lenient=lenient) +class ConnectivityMetadata(BaseMetadata): + """ + Metadata container for a :class:`~iris.coords.Connectivity`. + + """ + + # The "src_dim" member is stateful only, and does not participate in + # lenient/strict equivalence. + _members = ("cf_role", "start_index", "src_dim") + + __slots__ = () + + @wraps(BaseMetadata.__eq__, assigned=("__doc__",), updated=()) + @lenient_service + def __eq__(self, other): + return super().__eq__(other) + + def _combine_lenient(self, other): + """ + Perform lenient combination of metadata members for connectivities. + + Args: + + * other (ConnectivityMetadata): + The other connectivity metadata participating in the lenient + combination. + + Returns: + A list of combined metadata member values. + + """ + # Perform "strict" combination for "cf_role", "start_index", "src_dim". + def func(field): + left = getattr(self, field) + right = getattr(other, field) + return left if left == right else None + + # Note that, we use "_members" not "_fields". + values = [func(field) for field in ConnectivityMetadata._members] + # Perform lenient combination of the other parent members. + result = super()._combine_lenient(other) + result.extend(values) + + return result + + def _compare_lenient(self, other): + """ + Perform lenient equality of metadata members for connectivities. + + Args: + + * other (ConnectivityMetadata): + The other connectivity metadata participating in the lenient + comparison. + + Returns: + Boolean. + + """ + # Perform "strict" comparison for "cf_role", "start_index". + # The "src_dim" member is not part of lenient equivalence. + members = filter( + lambda member: member != "src_dim", ConnectivityMetadata._members + ) + result = all( + [ + getattr(self, field) == getattr(other, field) + for field in members + ] + ) + if result: + # Perform lenient comparison of the other parent members. + result = super()._compare_lenient(other) + + return result + + def _difference_lenient(self, other): + """ + Perform lenient difference of metadata members for connectivities. + + Args: + + * other (ConnectivityMetadata): + The other connectivity metadata participating in the lenient + difference. + + Returns: + A list of difference metadata member values. + + """ + # Perform "strict" difference for "cf_role", "start_index", "src_dim". + def func(field): + left = getattr(self, field) + right = getattr(other, field) + return None if left == right else (left, right) + + # Note that, we use "_members" not "_fields". + values = [func(field) for field in ConnectivityMetadata._members] + # Perform lenient difference of the other parent members. + result = super()._difference_lenient(other) + result.extend(values) + + return result + + @wraps(BaseMetadata.combine, assigned=("__doc__",), updated=()) + @lenient_service + def combine(self, other, lenient=None): + return super().combine(other, lenient=lenient) + + @wraps(BaseMetadata.difference, assigned=("__doc__",), updated=()) + @lenient_service + def difference(self, other, lenient=None): + return super().difference(other, lenient=lenient) + + @wraps(BaseMetadata.equal, assigned=("__doc__",), updated=()) + @lenient_service + def equal(self, other, lenient=None): + return super().equal(other, lenient=lenient) + + class CoordMetadata(BaseMetadata): """ Metadata container for a :class:`~iris.coords.Coord`. @@ -1459,6 +1581,7 @@ def values(self): AncillaryVariableMetadata.combine, BaseMetadata.combine, CellMeasureMetadata.combine, + ConnectivityMetadata.combine, CoordMetadata.combine, CubeMetadata.combine, DimCoordMetadata.combine, @@ -1470,6 +1593,7 @@ def values(self): AncillaryVariableMetadata.difference, BaseMetadata.difference, CellMeasureMetadata.difference, + ConnectivityMetadata.difference, CoordMetadata.difference, CubeMetadata.difference, DimCoordMetadata.difference, @@ -1484,6 +1608,8 @@ def values(self): BaseMetadata.equal, CellMeasureMetadata.__eq__, CellMeasureMetadata.equal, + ConnectivityMetadata.__eq__, + ConnectivityMetadata.equal, CoordMetadata.__eq__, CoordMetadata.equal, CubeMetadata.__eq__, diff --git a/lib/iris/coords.py b/lib/iris/coords.py index 76ca83cd96..1e91bd0448 100644 --- a/lib/iris/coords.py +++ b/lib/iris/coords.py @@ -19,6 +19,7 @@ import zlib import cftime +import dask.array as da import numpy as np import numpy.ma as ma @@ -30,6 +31,7 @@ BaseMetadata, CFVariableMixin, CellMeasureMetadata, + ConnectivityMetadata, CoordMetadata, DimCoordMetadata, metadata_manager_factory, @@ -624,6 +626,8 @@ def xml_element(self, doc): # otherwise. if isinstance(self, Coord): values_term = "points" + elif isinstance(self, Connectivity): + values_term = "indices" else: values_term = "data" element.setAttribute(values_term, self._xml_array_repr(self._values)) @@ -1914,7 +1918,6 @@ def collapsed(self, dims_to_collapse=None): Replaces the points & bounds with a simple bounded region. """ - import dask.array as da # Ensure dims_to_collapse is a tuple to be able to pass # through to numpy @@ -2775,6 +2778,467 @@ def xml_element(self, doc): return cellMethod_xml_element +class Connectivity(_DimensionalMetadata): + """ + A CF-UGRID topology connectivity, describing the topological relationship + between two lists of dimensional locations. One or more connectivities + make up a CF-UGRID topology - a constituent of a CF-UGRID mesh. + + See: https://ugrid-conventions.github.io/ugrid-conventions + + """ + + UGRID_CF_ROLES = [ + "edge_node_connectivity", + "face_node_connectivity", + "face_edge_connectivity", + "face_face_connectivity", + "edge_face_connectivity", + "boundary_node_connectivity", + "volume_node_connectivity", + "volume_edge_connectivity", + "volume_face_connectivity", + "volume_volume_connectivity", + ] + + def __init__( + self, + indices, + cf_role, + standard_name=None, + long_name=None, + var_name=None, + units=None, + attributes=None, + start_index=0, + src_dim=0, + ): + """ + Constructs a single connectivity. + + Args: + + * indices (numpy.ndarray or numpy.ma.core.MaskedArray or dask.array.Array): + The index values describing a topological relationship. Constructed + of 2 dimensions - the list of locations, and within each location: + the indices of the 'target locations' it relates to. + Use a :class:`numpy.ma.core.MaskedArray` if :attr:`src_location` + lengths vary - mask unused index 'slots' within each + :attr:`src_location`. Use a :class:`dask.array.Array` to keep + indices 'lazy'. + * cf_role (str): + Denotes the topological relationship that this connectivity + describes. Made up of this array's locations, and the indexed + 'target location' within each location. + See :attr:`UGRID_CF_ROLES` for valid arguments. + + Kwargs: + + * standard_name (str): + CF standard name of the connectivity. + (NOTE: this is not expected by the UGRID conventions, but will be + handled in Iris' standard way if provided). + * long_name (str): + Descriptive name of the connectivity. + * var_name (str): + The netCDF variable name for the connectivity. + * units (cf_units.Unit): + The :class:`~cf_units.Unit` of the connectivity's values. + Can be a string, which will be converted to a Unit object. + (NOTE: this is not expected by the UGRID conventions, but will be + handled in Iris' standard way if provided). + * attributes (dict): + A dictionary containing other cf and user-defined attributes. + * start_index (int): + Either ``0`` or ``1``. Default is ``0``. Denotes whether + :attr:`indices` uses 0-based or 1-based indexing (allows support + for Fortran and legacy NetCDF files). + * src_dim (int): + Either ``0`` or ``1``. Default is ``0``. Denotes which dimension + of :attr:`indices` varies over the :attr:`src_location`'s (the + alternate dimension therefore varying within individual + :attr:`src_location`'s). (This parameter allows support for fastest varying index being + either first or last). + E.g. for ``face_node_connectivity``, for 10 faces: + ``indices.shape[src_dim] = 10``. + + """ + + def validate_arg_vs_list(arg_name, arg, valid_list): + if arg not in valid_list: + error_msg = ( + f"Invalid {arg_name} . Got: {arg} . Must be one of: " + f"{valid_list} ." + ) + raise ValueError(error_msg) + + # Configure the metadata manager. + self._metadata_manager = metadata_manager_factory(ConnectivityMetadata) + + validate_arg_vs_list("start_index", start_index, [0, 1]) + # indices array will be 2-dimensional, so must be either 0 or 1. + validate_arg_vs_list("src_dim", src_dim, [0, 1]) + validate_arg_vs_list("cf_role", cf_role, Connectivity.UGRID_CF_ROLES) + + self._metadata_manager.start_index = start_index + self._metadata_manager.src_dim = src_dim + self._metadata_manager.cf_role = cf_role + + self._tgt_dim = 1 - src_dim + self._src_location, self._tgt_location = cf_role.split("_")[:2] + + super().__init__( + values=indices, + standard_name=standard_name, + long_name=long_name, + var_name=var_name, + units=units, + attributes=attributes, + ) + + @property + def _values(self): + # Overridden just to allow .setter override. + return super()._values + + @_values.setter + def _values(self, values): + self._validate_indices(values, shapes_only=True) + # The recommended way of using the setter in super(). + super(Connectivity, self.__class__)._values.fset(self, values) + + @property + def cf_role(self): + """ + The category of topological relationship that this connectivity + describes. + **Read-only** - validity of :attr:`indices` is dependent on + :attr:`cf_role`. A new :class:`Connectivity` must therefore be defined + if a different :attr:`cf_role` is needed. + + """ + return self._metadata_manager.cf_role + + @property + def src_location(self): + """ + Derived from the connectivity's :attr:`cf_role` - the first part, e.g. + ``face`` in ``face_node_connectivity``. Refers to the locations + listed by the :attr:`src_dim` of the connectivity's :attr:`indices` + array. + + """ + return self._src_location + + @property + def tgt_location(self): + """ + Derived from the connectivity's :attr:`cf_role` - the second part, e.g. + ``node`` in ``face_node_connectivity``. Refers to the locations indexed + by the values in the connectivity's :attr:`indices` array. + + """ + return self._tgt_location + + @property + def start_index(self): + """ + The base value of the connectivity's :attr:`indices` array; either + ``0`` or ``1``. + **Read-only** - validity of :attr:`indices` is dependent on + :attr:`start_index`. A new :class:`Connectivity` must therefore be + defined if a different :attr:`start_index` is needed. + + """ + return self._metadata_manager.start_index + + @property + def src_dim(self): + """ + The dimension of the connectivity's :attr:`indices` array that varies + over the connectivity's :attr:`src_location`'s. Either ``0`` or ``1``. + **Read-only** - validity of :attr:`indices` is dependent on + :attr:`src_dim`. Use :meth:`transpose` to create a new, transposed + :class:`Connectivity` if a different :attr:`src_dim` is needed. + + """ + return self._metadata_manager.src_dim + + @property + def tgt_dim(self): + """ + Derived as the alternate value of :attr:`src_dim` - each must equal + either ``0`` or ``1``. + The dimension of the connectivity's :attr:`indices` array that varies + within the connectivity's individual :attr:`src_location`'s. + + """ + return self._tgt_dim + + @property + def indices(self): + """ + The index values describing the topological relationship of the + connectivity, as a NumPy array. Masked points indicate a + :attr:`src_location` shorter than the longest :attr:`src_location` + described in this array - unused index 'slots' are masked. + **Read-only** - index values are only meaningful when combined with + an appropriate :attr:`cf_role`, :attr:`start_index` and + :attr:`src_dim`. A new :class:`Connectivity` must therefore be + defined if different indices are needed. + + """ + return self._values + + def indices_by_src(self, indices=None): + """ + Return a view of the indices array with :attr:`src_dim` **always** as + the first index - transposed if necessary. Can optionally pass in an + identically shaped array on which to perform this operation (e.g. the + output from :meth:`core_indices` or :meth:`lazy_indices`). + + Kwargs: + + * indices (array): + The array on which to operate. If ``None``, will operate on + :attr:`indices`. Default is ``None``. + + Returns: + A view of the indices array, transposed - if necessary - to put + :attr:`src_dim` first. + + """ + if indices is None: + indices = self.indices + + if indices.shape != self.shape: + raise ValueError( + f"Invalid indices provided. Must be shape={self.shape} , " + f"got shape={indices.shape} ." + ) + + if self.src_dim == 0: + result = indices + elif self.src_dim == 1: + result = indices.transpose() + else: + raise ValueError("Invalid src_dim.") + + return result + + def _validate_indices(self, indices, shapes_only=False): + # Use shapes_only=True for a lower resource, less thorough validation + # of indices by just inspecting the array shape instead of inspecting + # individual masks. So will not catch individual src_locations being + # unacceptably small. + + def indices_error(message): + raise ValueError("Invalid indices provided. " + message) + + indices = self._sanitise_array(indices, 0) + + indices_dtype = indices.dtype + if not np.issubdtype(indices_dtype, np.integer): + indices_error( + f"dtype must be numpy integer subtype, got: {indices_dtype} ." + ) + + indices_min = indices.min() + if _lazy.is_lazy_data(indices_min): + indices_min = indices_min.compute() + if indices_min < self.start_index: + indices_error( + f"Lowest index: {indices_min} < start_index: {self.start_index} ." + ) + + indices_shape = indices.shape + if len(indices_shape) != 2: + indices_error( + f"Expected 2-dimensional shape, got: shape={indices_shape} ." + ) + + len_req_fail = False + if shapes_only: + src_shape = indices_shape[self.tgt_dim] + # Wrap as lazy to allow use of the same operations below + # regardless of shapes_only. + src_lengths = _lazy.as_lazy_data(np.asarray(src_shape)) + else: + # Wouldn't be safe to use during __init__ validation, since + # lazy_src_lengths requires self.indices to exist. Safe here since + # shapes_only==False is only called manually, i.e. after + # initialisation. + src_lengths = self.lazy_src_lengths() + if self.src_location in ("edge", "boundary"): + if (src_lengths != 2).any().compute(): + len_req_fail = "len=2" + else: + if self.src_location == "face": + min_size = 3 + elif self.src_location == "volume": + if self.tgt_location == "edge": + min_size = 6 + else: + min_size = 4 + else: + raise NotImplementedError + if (src_lengths < min_size).any().compute(): + len_req_fail = f"len>={min_size}" + if len_req_fail: + indices_error( + f"Not all src_locations meet requirement: {len_req_fail} - " + f"needed to describe '{self.cf_role}' ." + ) + + def validate_indices(self): + """ + Perform a thorough validity check of this connectivity's + :attr:`indices`. Includes checking the sizes of individual + :attr:`src_location`'s (specified using masks on the + :attr:`indices` array) against the :attr:`cf_role`. + + Raises a ``ValueError`` if any problems are encountered, otherwise + passes silently. + + .. note:: + + While this uses lazy computation, it will still be a high + resource demand for a large :attr:`indices` array. + + """ + self._validate_indices(self.indices, shapes_only=False) + + def __eq__(self, other): + eq = NotImplemented + if isinstance(other, Connectivity): + # Account for the fact that other could be the transposed equivalent + # of self, which we consider 'safe' since the recommended + # interaction with the indices array is via indices_by_src, which + # corrects for this difference. (To enable this, src_dim does + # not participate in ConnectivityMetadata to ConnectivityMetadata + # equivalence). + if hasattr(other, "metadata"): + # metadata comparison + eq = self.metadata == other.metadata + if eq: + eq = ( + self.indices_by_src() == other.indices_by_src() + ).all() + return eq + + def transpose(self): + """ + Create a new :class:`Connectivity`, identical to this one but with the + :attr:`indices` array transposed and the :attr:`src_dim` value flipped. + + Returns: + A new :class:`Connectivity` that is the transposed equivalent of + the original. + + """ + new_connectivity = Connectivity( + indices=self.indices.transpose().copy(), + cf_role=self.cf_role, + standard_name=self.standard_name, + long_name=self.long_name, + var_name=self.var_name, + units=self.units, + attributes=self.attributes, + start_index=self.start_index, + src_dim=self.tgt_dim, + ) + return new_connectivity + + def lazy_indices(self): + """ + Return a lazy array representing the connectivity's indices. + + Accessing this method will never cause the :attr:`indices` values to be + loaded. Similarly, calling methods on, or indexing, the returned Array + will not cause the connectivity to have loaded :attr:`indices`. + + If the :attr:`indices` have already been loaded for the connectivity, + the returned Array will be a new lazy array wrapper. + + Returns: + A lazy array, representing the connectivity indices array. + + """ + return super()._lazy_values() + + def core_indices(self): + """ + The indices array at the core of this connectivity, which may be a + NumPy array or a Dask array. + + Returns: + numpy.ndarray or numpy.ma.core.MaskedArray or dask.array.Array + + """ + return super()._core_values() + + def has_lazy_indices(self): + """ + Return a boolean indicating whether the connectivity's :attr:`indices` + array is a lazy Dask array or not. + + Returns: + boolean + + """ + return super()._has_lazy_values() + + def lazy_src_lengths(self): + """ + Return a lazy array representing the lengths of each + :attr:`src_location` in the :attr:`src_dim` of the connectivity's + :attr:`indices` array, accounting for masks if present. + + Accessing this method will never cause the :attr:`indices` values to be + loaded. Similarly, calling methods on, or indexing, the returned Array + will not cause the connectivity to have loaded :attr:`indices`. + + The returned Array will be lazy regardless of whether the + :attr:`indices` have already been loaded. + + Returns: + A lazy array, representing the lengths of each :attr:`src_location`. + + """ + src_mask_counts = da.sum( + da.ma.getmaskarray(self.indices), axis=self.tgt_dim + ) + max_src_size = self.indices.shape[self.tgt_dim] + return max_src_size - src_mask_counts + + def src_lengths(self): + """ + Return a NumPy array representing the lengths of each + :attr:`src_location` in the :attr:`src_dim` of the connectivity's + :attr:`indices` array, accounting for masks if present. + + Returns: + A NumPy array, representing the lengths of each :attr:`src_location`. + + """ + return self.lazy_src_lengths().compute() + + def cube_dims(self, cube): + """Not available on :class:`Connectivity`.""" + raise NotImplementedError + + def xml_element(self, doc): + # Create the XML element as the camelCaseEquivalent of the + # class name + element = super().xml_element(doc) + + element.setAttribute("cf_role", self.cf_role) + element.setAttribute("start_index", self.start_index) + element.setAttribute("src_dim", self.src_dim) + + return element + + # See Coord.cells() for the description/context. class _CellIterator(Iterator): def __init__(self, coord): diff --git a/lib/iris/tests/unit/common/metadata/test_ConnectivityMetadata.py b/lib/iris/tests/unit/common/metadata/test_ConnectivityMetadata.py new file mode 100644 index 0000000000..c09cdc429c --- /dev/null +++ b/lib/iris/tests/unit/common/metadata/test_ConnectivityMetadata.py @@ -0,0 +1,774 @@ +# Copyright Iris contributors +# +# This file is part of Iris and is released under the LGPL license. +# See COPYING and COPYING.LESSER in the root of the repository for full +# licensing details. +""" +Unit tests for the :class:`iris.common.metadata.ConnectivityMetadata`. + +""" + +# Import iris.tests first so that some things can be initialised before +# importing anything else. +import iris.tests as tests + +from copy import deepcopy +import unittest.mock as mock +from unittest.mock import sentinel + +from iris.common.lenient import _LENIENT, _qualname +from iris.common.metadata import BaseMetadata, ConnectivityMetadata + + +class Test(tests.IrisTest): + def setUp(self): + self.standard_name = mock.sentinel.standard_name + self.long_name = mock.sentinel.long_name + self.var_name = mock.sentinel.var_name + self.units = mock.sentinel.units + self.attributes = mock.sentinel.attributes + self.cf_role = mock.sentinel.cf_role + self.start_index = mock.sentinel.start_index + self.src_dim = mock.sentinel.src_dim + self.cls = ConnectivityMetadata + + def test_repr(self): + metadata = self.cls( + standard_name=self.standard_name, + long_name=self.long_name, + var_name=self.var_name, + units=self.units, + attributes=self.attributes, + cf_role=self.cf_role, + start_index=self.start_index, + src_dim=self.src_dim, + ) + fmt = ( + "ConnectivityMetadata(standard_name={!r}, long_name={!r}, " + "var_name={!r}, units={!r}, attributes={!r}, cf_role={!r}, " + "start_index={!r}, src_dim={!r})" + ) + expected = fmt.format( + self.standard_name, + self.long_name, + self.var_name, + self.units, + self.attributes, + self.cf_role, + self.start_index, + self.src_dim, + ) + self.assertEqual(expected, repr(metadata)) + + def test__fields(self): + expected = ( + "standard_name", + "long_name", + "var_name", + "units", + "attributes", + "cf_role", + "start_index", + "src_dim", + ) + self.assertEqual(self.cls._fields, expected) + + def test_bases(self): + self.assertTrue(issubclass(self.cls, BaseMetadata)) + + +class Test__eq__(tests.IrisTest): + def setUp(self): + self.values = dict( + standard_name=sentinel.standard_name, + long_name=sentinel.long_name, + var_name=sentinel.var_name, + units=sentinel.units, + attributes=sentinel.attributes, + cf_role=sentinel.cf_role, + start_index=sentinel.start_index, + src_dim=sentinel.src_dim, + ) + self.dummy = sentinel.dummy + self.cls = ConnectivityMetadata + # The "src_dim" member is stateful only, and does not participate in + # lenient/strict equivalence. + self.members_no_src_dim = filter( + lambda member: member != "src_dim", self.cls._members + ) + + def test_wraps_docstring(self): + self.assertEqual(BaseMetadata.__eq__.__doc__, self.cls.__eq__.__doc__) + + def test_lenient_service(self): + qualname___eq__ = _qualname(self.cls.__eq__) + self.assertIn(qualname___eq__, _LENIENT) + self.assertTrue(_LENIENT[qualname___eq__]) + self.assertTrue(_LENIENT[self.cls.__eq__]) + + def test_call(self): + other = sentinel.other + return_value = sentinel.return_value + metadata = self.cls(*(None,) * len(self.cls._fields)) + with mock.patch.object( + BaseMetadata, "__eq__", return_value=return_value + ) as mocker: + result = metadata.__eq__(other) + + self.assertEqual(return_value, result) + self.assertEqual(1, mocker.call_count) + (arg,), kwargs = mocker.call_args + self.assertEqual(other, arg) + self.assertEqual(dict(), kwargs) + + def test_op_lenient_same(self): + lmetadata = self.cls(**self.values) + rmetadata = self.cls(**self.values) + + with mock.patch("iris.common.metadata._LENIENT", return_value=True): + self.assertTrue(lmetadata.__eq__(rmetadata)) + self.assertTrue(rmetadata.__eq__(lmetadata)) + + def test_op_lenient_same_none(self): + lmetadata = self.cls(**self.values) + right = self.values.copy() + right["var_name"] = None + rmetadata = self.cls(**right) + + with mock.patch("iris.common.metadata._LENIENT", return_value=True): + self.assertTrue(lmetadata.__eq__(rmetadata)) + self.assertTrue(rmetadata.__eq__(lmetadata)) + + def test_op_lenient_same_members_none(self): + for member in self.members_no_src_dim: + lmetadata = self.cls(**self.values) + right = self.values.copy() + right[member] = None + rmetadata = self.cls(**right) + + with mock.patch( + "iris.common.metadata._LENIENT", return_value=True + ): + self.assertFalse(lmetadata.__eq__(rmetadata)) + self.assertFalse(rmetadata.__eq__(lmetadata)) + + def test_op_lenient_same_src_dim_none(self): + lmetadata = self.cls(**self.values) + right = self.values.copy() + right["src_dim"] = None + rmetadata = self.cls(**right) + + with mock.patch("iris.common.metadata._LENIENT", return_value=True): + self.assertTrue(lmetadata.__eq__(rmetadata)) + self.assertTrue(rmetadata.__eq__(lmetadata)) + + def test_op_lenient_different(self): + lmetadata = self.cls(**self.values) + right = self.values.copy() + right["units"] = self.dummy + rmetadata = self.cls(**right) + + with mock.patch("iris.common.metadata._LENIENT", return_value=True): + self.assertFalse(lmetadata.__eq__(rmetadata)) + self.assertFalse(rmetadata.__eq__(lmetadata)) + + def test_op_lenient_different_members(self): + for member in self.members_no_src_dim: + lmetadata = self.cls(**self.values) + right = self.values.copy() + right[member] = self.dummy + rmetadata = self.cls(**right) + + with mock.patch( + "iris.common.metadata._LENIENT", return_value=True + ): + self.assertFalse(lmetadata.__eq__(rmetadata)) + self.assertFalse(rmetadata.__eq__(lmetadata)) + + def test_op_lenient_different_src_dim(self): + lmetadata = self.cls(**self.values) + right = self.values.copy() + right["src_dim"] = self.dummy + rmetadata = self.cls(**right) + + with mock.patch("iris.common.metadata._LENIENT", return_value=True): + self.assertTrue(lmetadata.__eq__(rmetadata)) + self.assertTrue(rmetadata.__eq__(lmetadata)) + + def test_op_strict_same(self): + lmetadata = self.cls(**self.values) + rmetadata = self.cls(**self.values) + + with mock.patch("iris.common.metadata._LENIENT", return_value=False): + self.assertTrue(lmetadata.__eq__(rmetadata)) + self.assertTrue(rmetadata.__eq__(lmetadata)) + + def test_op_strict_different(self): + lmetadata = self.cls(**self.values) + right = self.values.copy() + right["long_name"] = self.dummy + rmetadata = self.cls(**right) + + with mock.patch("iris.common.metadata._LENIENT", return_value=False): + self.assertFalse(lmetadata.__eq__(rmetadata)) + self.assertFalse(rmetadata.__eq__(lmetadata)) + + def test_op_strict_different_members(self): + for member in self.members_no_src_dim: + lmetadata = self.cls(**self.values) + right = self.values.copy() + right[member] = self.dummy + rmetadata = self.cls(**right) + + with mock.patch( + "iris.common.metadata._LENIENT", return_value=False + ): + self.assertFalse(lmetadata.__eq__(rmetadata)) + self.assertFalse(rmetadata.__eq__(lmetadata)) + + def test_op_strict_different_src_dim(self): + lmetadata = self.cls(**self.values) + right = self.values.copy() + right["src_dim"] = self.dummy + rmetadata = self.cls(**right) + + with mock.patch("iris.common.metadata._LENIENT", return_value=False): + self.assertTrue(lmetadata.__eq__(rmetadata)) + self.assertTrue(rmetadata.__eq__(lmetadata)) + + def test_op_strict_different_none(self): + lmetadata = self.cls(**self.values) + right = self.values.copy() + right["long_name"] = None + rmetadata = self.cls(**right) + + with mock.patch("iris.common.metadata._LENIENT", return_value=False): + self.assertFalse(lmetadata.__eq__(rmetadata)) + self.assertFalse(rmetadata.__eq__(lmetadata)) + + def test_op_strict_different_members_none(self): + for member in self.members_no_src_dim: + lmetadata = self.cls(**self.values) + right = self.values.copy() + right[member] = None + rmetadata = self.cls(**right) + + with mock.patch( + "iris.common.metadata._LENIENT", return_value=False + ): + self.assertFalse(lmetadata.__eq__(rmetadata)) + self.assertFalse(rmetadata.__eq__(lmetadata)) + + def test_op_strict_different_src_dim_none(self): + lmetadata = self.cls(**self.values) + right = self.values.copy() + right["src_dim"] = None + rmetadata = self.cls(**right) + + with mock.patch("iris.common.metadata._LENIENT", return_value=False): + self.assertTrue(lmetadata.__eq__(rmetadata)) + self.assertTrue(rmetadata.__eq__(lmetadata)) + + +class Test___lt__(tests.IrisTest): + def setUp(self): + self.cls = ConnectivityMetadata + self.one = self.cls(1, 1, 1, 1, 1, 1, 1, 1) + self.two = self.cls(1, 1, 1, 2, 1, 1, 1, 1) + self.none = self.cls(1, 1, 1, None, 1, 1, 1, 1) + self.attributes = self.cls(1, 1, 1, 1, 10, 1, 1, 1) + + def test__ascending_lt(self): + result = self.one < self.two + self.assertTrue(result) + + def test__descending_lt(self): + result = self.two < self.one + self.assertFalse(result) + + def test__none_rhs_operand(self): + result = self.one < self.none + self.assertFalse(result) + + def test__none_lhs_operand(self): + result = self.none < self.one + self.assertTrue(result) + + def test__ignore_attributes(self): + result = self.one < self.attributes + self.assertFalse(result) + result = self.attributes < self.one + self.assertFalse(result) + + +class Test_combine(tests.IrisTest): + def setUp(self): + self.values = dict( + standard_name=sentinel.standard_name, + long_name=sentinel.long_name, + var_name=sentinel.var_name, + units=sentinel.units, + attributes=sentinel.attributes, + cf_role=sentinel.cf_role, + start_index=sentinel.start_index, + src_dim=sentinel.src_dim, + ) + self.dummy = sentinel.dummy + self.cls = ConnectivityMetadata + self.none = self.cls(*(None,) * len(self.cls._fields)) + + def test_wraps_docstring(self): + self.assertEqual( + BaseMetadata.combine.__doc__, self.cls.combine.__doc__ + ) + + def test_lenient_service(self): + qualname_combine = _qualname(self.cls.combine) + self.assertIn(qualname_combine, _LENIENT) + self.assertTrue(_LENIENT[qualname_combine]) + self.assertTrue(_LENIENT[self.cls.combine]) + + def test_lenient_default(self): + other = sentinel.other + return_value = sentinel.return_value + with mock.patch.object( + BaseMetadata, "combine", return_value=return_value + ) as mocker: + result = self.none.combine(other) + + self.assertEqual(return_value, result) + self.assertEqual(1, mocker.call_count) + (arg,), kwargs = mocker.call_args + self.assertEqual(other, arg) + self.assertEqual(dict(lenient=None), kwargs) + + def test_lenient(self): + other = sentinel.other + lenient = sentinel.lenient + return_value = sentinel.return_value + with mock.patch.object( + BaseMetadata, "combine", return_value=return_value + ) as mocker: + result = self.none.combine(other, lenient=lenient) + + self.assertEqual(return_value, result) + self.assertEqual(1, mocker.call_count) + (arg,), kwargs = mocker.call_args + self.assertEqual(other, arg) + self.assertEqual(dict(lenient=lenient), kwargs) + + def test_op_lenient_same(self): + lmetadata = self.cls(**self.values) + rmetadata = self.cls(**self.values) + expected = self.values + + with mock.patch("iris.common.metadata._LENIENT", return_value=True): + self.assertEqual(expected, lmetadata.combine(rmetadata)._asdict()) + self.assertEqual(expected, rmetadata.combine(lmetadata)._asdict()) + + def test_op_lenient_same_none(self): + lmetadata = self.cls(**self.values) + right = self.values.copy() + right["var_name"] = None + rmetadata = self.cls(**right) + expected = self.values + + with mock.patch("iris.common.metadata._LENIENT", return_value=True): + self.assertEqual(expected, lmetadata.combine(rmetadata)._asdict()) + self.assertEqual(expected, rmetadata.combine(lmetadata)._asdict()) + + def test_op_lenient_same_members_none(self): + for member in self.cls._members: + lmetadata = self.cls(**self.values) + right = self.values.copy() + right[member] = None + rmetadata = self.cls(**right) + expected = right.copy() + + with mock.patch( + "iris.common.metadata._LENIENT", return_value=True + ): + self.assertTrue( + expected, lmetadata.combine(rmetadata)._asdict() + ) + self.assertTrue( + expected, rmetadata.combine(lmetadata)._asdict() + ) + + def test_op_lenient_different(self): + lmetadata = self.cls(**self.values) + right = self.values.copy() + right["units"] = self.dummy + rmetadata = self.cls(**right) + expected = self.values.copy() + expected["units"] = None + + with mock.patch("iris.common.metadata._LENIENT", return_value=True): + self.assertEqual(expected, lmetadata.combine(rmetadata)._asdict()) + self.assertEqual(expected, rmetadata.combine(lmetadata)._asdict()) + + def test_op_lenient_different_members(self): + for member in self.cls._members: + lmetadata = self.cls(**self.values) + right = self.values.copy() + right[member] = self.dummy + rmetadata = self.cls(**right) + expected = self.values.copy() + expected[member] = None + + with mock.patch( + "iris.common.metadata._LENIENT", return_value=True + ): + self.assertEqual( + expected, lmetadata.combine(rmetadata)._asdict() + ) + self.assertEqual( + expected, rmetadata.combine(lmetadata)._asdict() + ) + + def test_op_strict_same(self): + lmetadata = self.cls(**self.values) + rmetadata = self.cls(**self.values) + expected = self.values.copy() + + with mock.patch("iris.common.metadata._LENIENT", return_value=False): + self.assertEqual(expected, lmetadata.combine(rmetadata)._asdict()) + self.assertEqual(expected, rmetadata.combine(lmetadata)._asdict()) + + def test_op_strict_different(self): + lmetadata = self.cls(**self.values) + right = self.values.copy() + right["long_name"] = self.dummy + rmetadata = self.cls(**right) + expected = self.values.copy() + expected["long_name"] = None + + with mock.patch("iris.common.metadata._LENIENT", return_value=False): + self.assertEqual(expected, lmetadata.combine(rmetadata)._asdict()) + self.assertEqual(expected, rmetadata.combine(lmetadata)._asdict()) + + def test_op_strict_different_members(self): + for member in self.cls._members: + lmetadata = self.cls(**self.values) + right = self.values.copy() + right[member] = self.dummy + rmetadata = self.cls(**right) + expected = self.values.copy() + expected[member] = None + + with mock.patch( + "iris.common.metadata._LENIENT", return_value=False + ): + self.assertEqual( + expected, lmetadata.combine(rmetadata)._asdict() + ) + self.assertEqual( + expected, rmetadata.combine(lmetadata)._asdict() + ) + + def test_op_strict_different_none(self): + lmetadata = self.cls(**self.values) + right = self.values.copy() + right["long_name"] = None + rmetadata = self.cls(**right) + expected = self.values.copy() + expected["long_name"] = None + + with mock.patch("iris.common.metadata._LENIENT", return_value=False): + self.assertEqual(expected, lmetadata.combine(rmetadata)._asdict()) + self.assertEqual(expected, rmetadata.combine(lmetadata)._asdict()) + + def test_op_strict_different_members_none(self): + for member in self.cls._members: + lmetadata = self.cls(**self.values) + right = self.values.copy() + right[member] = None + rmetadata = self.cls(**right) + expected = self.values.copy() + expected[member] = None + + with mock.patch( + "iris.common.metadata._LENIENT", return_value=False + ): + self.assertEqual( + expected, lmetadata.combine(rmetadata)._asdict() + ) + self.assertEqual( + expected, rmetadata.combine(lmetadata)._asdict() + ) + + +class Test_difference(tests.IrisTest): + def setUp(self): + self.values = dict( + standard_name=sentinel.standard_name, + long_name=sentinel.long_name, + var_name=sentinel.var_name, + units=sentinel.units, + attributes=sentinel.attributes, + cf_role=sentinel.cf_role, + start_index=sentinel.start_index, + src_dim=sentinel.src_dim, + ) + self.dummy = sentinel.dummy + self.cls = ConnectivityMetadata + self.none = self.cls(*(None,) * len(self.cls._fields)) + + def test_wraps_docstring(self): + self.assertEqual( + BaseMetadata.difference.__doc__, self.cls.difference.__doc__ + ) + + def test_lenient_service(self): + qualname_difference = _qualname(self.cls.difference) + self.assertIn(qualname_difference, _LENIENT) + self.assertTrue(_LENIENT[qualname_difference]) + self.assertTrue(_LENIENT[self.cls.difference]) + + def test_lenient_default(self): + other = sentinel.other + return_value = sentinel.return_value + with mock.patch.object( + BaseMetadata, "difference", return_value=return_value + ) as mocker: + result = self.none.difference(other) + + self.assertEqual(return_value, result) + self.assertEqual(1, mocker.call_count) + (arg,), kwargs = mocker.call_args + self.assertEqual(other, arg) + self.assertEqual(dict(lenient=None), kwargs) + + def test_lenient(self): + other = sentinel.other + lenient = sentinel.lenient + return_value = sentinel.return_value + with mock.patch.object( + BaseMetadata, "difference", return_value=return_value + ) as mocker: + result = self.none.difference(other, lenient=lenient) + + self.assertEqual(return_value, result) + self.assertEqual(1, mocker.call_count) + (arg,), kwargs = mocker.call_args + self.assertEqual(other, arg) + self.assertEqual(dict(lenient=lenient), kwargs) + + def test_op_lenient_same(self): + lmetadata = self.cls(**self.values) + rmetadata = self.cls(**self.values) + + with mock.patch("iris.common.metadata._LENIENT", return_value=True): + self.assertIsNone(lmetadata.difference(rmetadata)) + self.assertIsNone(rmetadata.difference(lmetadata)) + + def test_op_lenient_same_none(self): + lmetadata = self.cls(**self.values) + right = self.values.copy() + right["var_name"] = None + rmetadata = self.cls(**right) + + with mock.patch("iris.common.metadata._LENIENT", return_value=True): + self.assertIsNone(lmetadata.difference(rmetadata)) + self.assertIsNone(rmetadata.difference(lmetadata)) + + def test_op_lenient_same_members_none(self): + for member in self.cls._members: + lmetadata = self.cls(**self.values) + member_value = getattr(lmetadata, member) + right = self.values.copy() + right[member] = None + rmetadata = self.cls(**right) + lexpected = deepcopy(self.none)._asdict() + lexpected[member] = (member_value, None) + rexpected = deepcopy(self.none)._asdict() + rexpected[member] = (None, member_value) + + with mock.patch( + "iris.common.metadata._LENIENT", return_value=True + ): + self.assertEqual( + lexpected, lmetadata.difference(rmetadata)._asdict() + ) + self.assertEqual( + rexpected, rmetadata.difference(lmetadata)._asdict() + ) + + def test_op_lenient_different(self): + left = self.values.copy() + lmetadata = self.cls(**left) + right = self.values.copy() + right["units"] = self.dummy + rmetadata = self.cls(**right) + lexpected = deepcopy(self.none)._asdict() + lexpected["units"] = (left["units"], right["units"]) + rexpected = deepcopy(self.none)._asdict() + rexpected["units"] = lexpected["units"][::-1] + + with mock.patch("iris.common.metadata._LENIENT", return_value=True): + self.assertEqual( + lexpected, lmetadata.difference(rmetadata)._asdict() + ) + self.assertEqual( + rexpected, rmetadata.difference(lmetadata)._asdict() + ) + + def test_op_lenient_different_members(self): + for member in self.cls._members: + left = self.values.copy() + lmetadata = self.cls(**left) + right = self.values.copy() + right[member] = self.dummy + rmetadata = self.cls(**right) + lexpected = deepcopy(self.none)._asdict() + lexpected[member] = (left[member], right[member]) + rexpected = deepcopy(self.none)._asdict() + rexpected[member] = lexpected[member][::-1] + + with mock.patch( + "iris.common.metadata._LENIENT", return_value=True + ): + self.assertEqual( + lexpected, lmetadata.difference(rmetadata)._asdict() + ) + self.assertEqual( + rexpected, rmetadata.difference(lmetadata)._asdict() + ) + + def test_op_strict_same(self): + lmetadata = self.cls(**self.values) + rmetadata = self.cls(**self.values) + + with mock.patch("iris.common.metadata._LENIENT", return_value=False): + self.assertIsNone(lmetadata.difference(rmetadata)) + self.assertIsNone(rmetadata.difference(lmetadata)) + + def test_op_strict_different(self): + left = self.values.copy() + lmetadata = self.cls(**left) + right = self.values.copy() + right["long_name"] = self.dummy + rmetadata = self.cls(**right) + lexpected = deepcopy(self.none)._asdict() + lexpected["long_name"] = (left["long_name"], right["long_name"]) + rexpected = deepcopy(self.none)._asdict() + rexpected["long_name"] = lexpected["long_name"][::-1] + + with mock.patch("iris.common.metadata._LENIENT", return_value=False): + self.assertEqual( + lexpected, lmetadata.difference(rmetadata)._asdict() + ) + self.assertEqual( + rexpected, rmetadata.difference(lmetadata)._asdict() + ) + + def test_op_strict_different_members(self): + for member in self.cls._members: + left = self.values.copy() + lmetadata = self.cls(**left) + right = self.values.copy() + right[member] = self.dummy + rmetadata = self.cls(**right) + lexpected = deepcopy(self.none)._asdict() + lexpected[member] = (left[member], right[member]) + rexpected = deepcopy(self.none)._asdict() + rexpected[member] = lexpected[member][::-1] + + with mock.patch( + "iris.common.metadata._LENIENT", return_value=False + ): + self.assertEqual( + lexpected, lmetadata.difference(rmetadata)._asdict() + ) + self.assertEqual( + rexpected, rmetadata.difference(lmetadata)._asdict() + ) + + def test_op_strict_different_none(self): + left = self.values.copy() + lmetadata = self.cls(**left) + right = self.values.copy() + right["long_name"] = None + rmetadata = self.cls(**right) + lexpected = deepcopy(self.none)._asdict() + lexpected["long_name"] = (left["long_name"], right["long_name"]) + rexpected = deepcopy(self.none)._asdict() + rexpected["long_name"] = lexpected["long_name"][::-1] + + with mock.patch("iris.common.metadata._LENIENT", return_value=False): + self.assertEqual( + lexpected, lmetadata.difference(rmetadata)._asdict() + ) + self.assertEqual( + rexpected, rmetadata.difference(lmetadata)._asdict() + ) + + def test_op_strict_different_members_none(self): + for member in self.cls._members: + left = self.values.copy() + lmetadata = self.cls(**left) + right = self.values.copy() + right[member] = None + rmetadata = self.cls(**right) + lexpected = deepcopy(self.none)._asdict() + lexpected[member] = (left[member], right[member]) + rexpected = deepcopy(self.none)._asdict() + rexpected[member] = lexpected[member][::-1] + + with mock.patch( + "iris.common.metadata._LENIENT", return_value=False + ): + self.assertEqual( + lexpected, lmetadata.difference(rmetadata)._asdict() + ) + self.assertEqual( + rexpected, rmetadata.difference(lmetadata)._asdict() + ) + + +class Test_equal(tests.IrisTest): + def setUp(self): + self.cls = ConnectivityMetadata + self.none = self.cls(*(None,) * len(self.cls._fields)) + + def test_wraps_docstring(self): + self.assertEqual(BaseMetadata.equal.__doc__, self.cls.equal.__doc__) + + def test_lenient_service(self): + qualname_equal = _qualname(self.cls.equal) + self.assertIn(qualname_equal, _LENIENT) + self.assertTrue(_LENIENT[qualname_equal]) + self.assertTrue(_LENIENT[self.cls.equal]) + + def test_lenient_default(self): + other = sentinel.other + return_value = sentinel.return_value + with mock.patch.object( + BaseMetadata, "equal", return_value=return_value + ) as mocker: + result = self.none.equal(other) + + self.assertEqual(return_value, result) + self.assertEqual(1, mocker.call_count) + (arg,), kwargs = mocker.call_args + self.assertEqual(other, arg) + self.assertEqual(dict(lenient=None), kwargs) + + def test_lenient(self): + other = sentinel.other + lenient = sentinel.lenient + return_value = sentinel.return_value + with mock.patch.object( + BaseMetadata, "equal", return_value=return_value + ) as mocker: + result = self.none.equal(other, lenient=lenient) + + self.assertEqual(return_value, result) + self.assertEqual(1, mocker.call_count) + (arg,), kwargs = mocker.call_args + self.assertEqual(other, arg) + self.assertEqual(dict(lenient=lenient), kwargs) + + +if __name__ == "__main__": + tests.main() diff --git a/lib/iris/tests/unit/common/metadata/test_metadata_manager_factory.py b/lib/iris/tests/unit/common/metadata/test_metadata_manager_factory.py index 6678aca446..b24dc8b991 100644 --- a/lib/iris/tests/unit/common/metadata/test_metadata_manager_factory.py +++ b/lib/iris/tests/unit/common/metadata/test_metadata_manager_factory.py @@ -21,6 +21,7 @@ AncillaryVariableMetadata, BaseMetadata, CellMeasureMetadata, + ConnectivityMetadata, CoordMetadata, CubeMetadata, metadata_manager_factory, @@ -31,6 +32,7 @@ AncillaryVariableMetadata, BaseMetadata, CellMeasureMetadata, + ConnectivityMetadata, CoordMetadata, CubeMetadata, ] diff --git a/lib/iris/tests/unit/common/mixin/test_CFVariableMixin.py b/lib/iris/tests/unit/common/mixin/test_CFVariableMixin.py index 5ac9361e4f..3b30282a9e 100644 --- a/lib/iris/tests/unit/common/mixin/test_CFVariableMixin.py +++ b/lib/iris/tests/unit/common/mixin/test_CFVariableMixin.py @@ -21,6 +21,7 @@ AncillaryVariableMetadata, BaseMetadata, CellMeasureMetadata, + ConnectivityMetadata, CoordMetadata, CubeMetadata, ) @@ -284,6 +285,19 @@ def test_class_cellmeasuremetadata(self): self.item._metadata_manager.attributes, metadata.attributes ) + def test_class_connectivitymetadata(self): + self.args.update(dict(cf_role=None, start_index=None, src_dim=None)) + metadata = ConnectivityMetadata(**self.args) + self.item.metadata = metadata + expected = metadata._asdict() + del expected["cf_role"] + del expected["start_index"] + del expected["src_dim"] + self.assertEqual(self.item._metadata_manager.values, expected) + self.assertIsNot( + self.item._metadata_manager.attributes, metadata.attributes + ) + def test_class_coordmetadata(self): self.args.update(dict(coord_system=None, climatological=False)) metadata = CoordMetadata(**self.args) diff --git a/lib/iris/tests/unit/coords/test_Connectivity.py b/lib/iris/tests/unit/coords/test_Connectivity.py new file mode 100644 index 0000000000..826d52cafe --- /dev/null +++ b/lib/iris/tests/unit/coords/test_Connectivity.py @@ -0,0 +1,350 @@ +# Copyright Iris contributors +# +# This file is part of Iris and is released under the LGPL license. +# See COPYING and COPYING.LESSER in the root of the repository for full +# licensing details. +"""Unit tests for the :class:`iris.coords.Connectivity` class.""" + +# Import iris.tests first so that some things can be initialised before +# importing anything else. +import iris.tests as tests + +from xml.dom import minidom + +import numpy as np +from numpy import ma + +from iris.coords import Connectivity +from iris._lazy_data import as_lazy_data, is_lazy_data + + +class TestStandard(tests.IrisTest): + def setUp(self): + # Crete an instance, with non-default arguments to allow testing of + # correct property setting. + self.kwargs = { + "indices": np.linspace(1, 9, 9, dtype=int).reshape((3, -1)), + "cf_role": "face_node_connectivity", + "long_name": "my_face_nodes", + "var_name": "face_nodes", + "attributes": {"notes": "this is a test"}, + "start_index": 1, + "src_dim": 1, + } + self.connectivity = Connectivity(**self.kwargs) + + def test_cf_role(self): + self.assertEqual(self.kwargs["cf_role"], self.connectivity.cf_role) + + def test_src_location(self): + expected = self.kwargs["cf_role"].split("_")[0] + self.assertEqual(expected, self.connectivity.src_location) + + def test_tgt_location(self): + expected = self.kwargs["cf_role"].split("_")[1] + self.assertEqual(expected, self.connectivity.tgt_location) + + def test_start_index(self): + self.assertEqual( + self.kwargs["start_index"], self.connectivity.start_index + ) + + def test_src_dim(self): + self.assertEqual(self.kwargs["src_dim"], self.connectivity.src_dim) + + def test_indices(self): + self.assertArrayEqual( + self.kwargs["indices"], self.connectivity.indices + ) + + def test_read_only(self): + attributes = ("indices", "cf_role", "start_index", "src_dim") + for attribute in attributes: + self.assertRaisesRegex( + AttributeError, + "can't set attribute", + setattr, + self.connectivity, + attribute, + 1, + ) + + def test_transpose(self): + expected_dim = 1 - self.kwargs["src_dim"] + expected_indices = self.kwargs["indices"].transpose() + new_connectivity = self.connectivity.transpose() + self.assertEqual(expected_dim, new_connectivity.src_dim) + self.assertArrayEqual(expected_indices, new_connectivity.indices) + + def test_lazy_indices(self): + self.assertTrue(is_lazy_data(self.connectivity.lazy_indices())) + + def test_core_indices(self): + self.assertArrayEqual( + self.kwargs["indices"], self.connectivity.core_indices() + ) + + def test_has_lazy_indices(self): + self.assertFalse(self.connectivity.has_lazy_indices()) + + def test_lazy_src_lengths(self): + self.assertTrue(is_lazy_data(self.connectivity.lazy_src_lengths())) + + def test_src_lengths(self): + expected = [3, 3, 3] + self.assertArrayEqual(expected, self.connectivity.src_lengths()) + + def test___str__(self): + expected = ( + "Connectivity(array([[1, 2, 3],\n [4, 5, 6],\n [7, 8, 9]]), " + "standard_name=None, units=Unit('unknown'), " + "long_name='my_face_nodes', var_name='face_nodes', " + "attributes={'notes': 'this is a test'})" + ) + self.assertEqual(expected, self.connectivity.__str__()) + + def test___repr__(self): + expected = ( + "Connectivity(array([[1, 2, 3],\n [4, 5, 6],\n [7, 8, 9]]), " + "standard_name=None, units=Unit('unknown'), " + "long_name='my_face_nodes', var_name='face_nodes', " + "attributes={'notes': 'this is a test'})" + ) + self.assertEqual(expected, self.connectivity.__repr__()) + + def test_xml_element(self): + doc = minidom.Document() + connectivity_element = self.connectivity.xml_element(doc) + self.assertEqual(connectivity_element.tagName, "connectivity") + for attribute in ("cf_role", "start_index", "src_dim"): + self.assertIn(attribute, connectivity_element.attributes) + + def test___eq__(self): + equivalent_kwargs = self.kwargs + equivalent_kwargs["indices"] = self.kwargs["indices"].transpose() + equivalent_kwargs["src_dim"] = 1 - self.kwargs["src_dim"] + equivalent = Connectivity(**equivalent_kwargs) + self.assertFalse( + (equivalent.indices == self.connectivity.indices).all() + ) + self.assertEqual(equivalent, self.connectivity) + + def test_different(self): + different_kwargs = self.kwargs + different_kwargs["indices"] = self.kwargs["indices"].transpose() + different = Connectivity(**different_kwargs) + self.assertNotEqual(different, self.connectivity) + + def test_no_cube_dims(self): + self.assertRaises(NotImplementedError, self.connectivity.cube_dims, 1) + + def test_shape(self): + self.assertEqual(self.kwargs["indices"].shape, self.connectivity.shape) + + def test_ndim(self): + self.assertEqual(self.kwargs["indices"].ndim, self.connectivity.ndim) + + def test___getitem_(self): + subset = self.connectivity[:, 0:1] + self.assertArrayEqual(self.kwargs["indices"][:, 0:1], subset.indices) + + def test_copy(self): + new_indices = np.linspace(11, 16, 6, dtype=int).reshape((3, -1)) + copy_connectivity = self.connectivity.copy(new_indices) + self.assertArrayEqual(new_indices, copy_connectivity.indices) + + def test_indices_by_src(self): + expected = self.kwargs["indices"].transpose() + self.assertArrayEqual(expected, self.connectivity.indices_by_src()) + + def test_indices_by_src_input(self): + expected = as_lazy_data(self.kwargs["indices"].transpose()) + by_src = self.connectivity.indices_by_src( + self.connectivity.lazy_indices() + ) + self.assertArrayEqual(expected, by_src) + + +class TestAltIndices(tests.IrisTest): + def setUp(self): + mask = ([0, 0, 0, 0, 1] * 2) + [0, 0, 0, 1, 1] + data = np.linspace(1, 15, 15, dtype=int).reshape((-1, 5)) + self.masked_indices = ma.array(data=data, mask=mask) + self.lazy_indices = as_lazy_data(data) + + def common(self, indices): + connectivity = Connectivity( + indices=indices, cf_role="face_node_connectivity" + ) + self.assertArrayEqual(indices, connectivity.indices) + + def test_int32(self): + indices = np.linspace(1, 9, 9, dtype=np.int32).reshape((-1, 3)) + self.common(indices) + + def test_uint32(self): + indices = np.linspace(1, 9, 9, dtype=np.uint32).reshape((-1, 3)) + self.common(indices) + + def test_lazy(self): + self.common(self.lazy_indices) + + def test_masked(self): + self.common(self.masked_indices) + + def test_masked_lazy(self): + self.common(as_lazy_data(self.masked_indices)) + + def test_has_lazy_indices(self): + connectivity = Connectivity( + indices=self.lazy_indices, cf_role="face_node_connectivity" + ) + self.assertTrue(connectivity.has_lazy_indices()) + + +class TestValidations(tests.IrisTest): + def test_start_index(self): + kwargs = { + "indices": np.linspace(1, 9, 9, dtype=int).reshape((-1, 3)), + "cf_role": "face_node_connectivity", + "start_index": 2, + } + self.assertRaisesRegex( + ValueError, "Invalid start_index .", Connectivity, **kwargs + ) + + def test_src_dim(self): + kwargs = { + "indices": np.linspace(1, 9, 9, dtype=int).reshape((-1, 3)), + "cf_role": "face_node_connectivity", + "src_dim": 2, + } + self.assertRaisesRegex( + ValueError, "Invalid src_dim .", Connectivity, **kwargs + ) + + def test_cf_role(self): + kwargs = { + "indices": np.linspace(1, 9, 9, dtype=int).reshape((-1, 3)), + "cf_role": "error", + } + self.assertRaisesRegex( + ValueError, "Invalid cf_role .", Connectivity, **kwargs + ) + + def test_indices_int(self): + kwargs = { + "indices": np.linspace(1, 9, 9).reshape((-1, 3)), + "cf_role": "face_node_connectivity", + } + self.assertRaisesRegex( + ValueError, + "dtype must be numpy integer subtype", + Connectivity, + **kwargs, + ) + + def test_indices_start_index(self): + kwargs = { + "indices": np.linspace(-9, -1, 9, dtype=int).reshape((-1, 3)), + "cf_role": "face_node_connectivity", + } + self.assertRaisesRegex( + ValueError, " < start_index", Connectivity, **kwargs + ) + + def test_indices_dims_low(self): + kwargs = { + "indices": np.linspace(1, 9, 9, dtype=int), + "cf_role": "face_node_connectivity", + } + self.assertRaisesRegex( + ValueError, "Expected 2-dimensional shape,", Connectivity, **kwargs + ) + + def test_indices_dims_high(self): + kwargs = { + "indices": np.linspace(1, 12, 12, dtype=int).reshape((-1, 3, 2)), + "cf_role": "face_node_connectivity", + } + self.assertRaisesRegex( + ValueError, "Expected 2-dimensional shape,", Connectivity, **kwargs + ) + + def test_indices_locations_edge(self): + kwargs = { + "indices": np.linspace(1, 9, 9, dtype=int).reshape((-1, 3)), + "cf_role": "edge_node_connectivity", + } + self.assertRaisesRegex( + ValueError, + "Not all src_locations meet requirement: len=2", + Connectivity, + **kwargs, + ) + + def test_indices_locations_face(self): + kwargs = { + "indices": np.linspace(1, 6, 6, dtype=int).reshape((-1, 2)), + "cf_role": "face_node_connectivity", + } + self.assertRaisesRegex( + ValueError, + "Not all src_locations meet requirement: len>=3", + Connectivity, + **kwargs, + ) + + def test_indices_locations_volume_face(self): + kwargs = { + "indices": np.linspace(1, 9, 9, dtype=int).reshape((-1, 3)), + "cf_role": "volume_face_connectivity", + } + self.assertRaisesRegex( + ValueError, + "Not all src_locations meet requirement: len>=4", + Connectivity, + **kwargs, + ) + + def test_indices_locations_volume_edge(self): + kwargs = { + "indices": np.linspace(1, 12, 12, dtype=int).reshape((-1, 3)), + "cf_role": "volume_edge_connectivity", + } + self.assertRaisesRegex( + ValueError, + "Not all src_locations meet requirement: len>=6", + Connectivity, + **kwargs, + ) + + def test_indices_locations_alt_dim(self): + """The transposed equivalent of `test_indices_locations_volume_face`.""" + kwargs = { + "indices": np.linspace(1, 9, 9, dtype=int).reshape((3, -1)), + "cf_role": "volume_face_connectivity", + "src_dim": 1, + } + self.assertRaisesRegex( + ValueError, + "Not all src_locations meet requirement: len>=4", + Connectivity, + **kwargs, + ) + + def test_indices_locations_masked(self): + mask = ([0, 0, 0] * 2) + [0, 0, 1] + data = np.linspace(1, 9, 9, dtype=int).reshape((3, -1)) + kwargs = { + "indices": ma.array(data=data, mask=mask), + "cf_role": "face_node_connectivity", + } + # Validation of individual location sizes (denoted by masks) only + # available through explicit call of Connectivity.validate_indices(). + connectivity = Connectivity(**kwargs) + self.assertRaisesRegex( + ValueError, + "Not all src_locations meet requirement: len>=3", + connectivity.validate_indices, + )