From 673ccc352d6cda6e047f8b837266858eca162d27 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Sun, 18 Aug 2019 14:58:16 -0700 Subject: [PATCH 1/5] Explicitly keep track of indexes in merge.py --- xarray/core/alignment.py | 26 +-- xarray/core/computation.py | 39 ++-- xarray/core/coordinates.py | 91 ++++++---- xarray/core/dataarray.py | 2 +- xarray/core/dataset.py | 53 +++--- xarray/core/merge.py | 354 +++++++++++++++++++++---------------- 6 files changed, 312 insertions(+), 253 deletions(-) diff --git a/xarray/core/alignment.py b/xarray/core/alignment.py index bb44f48fb9b..d6e85a5548e 100644 --- a/xarray/core/alignment.py +++ b/xarray/core/alignment.py @@ -139,7 +139,7 @@ def align( all_indexes[dim].append(index) if join == "override": - objects = _override_indexes(list(objects), all_indexes, exclude) + objects = _override_indexes(objects, all_indexes, exclude) # We don't reindex over dimensions with all equal indexes for two reasons: # - It's faster for the usual case (already aligned objects). @@ -236,26 +236,27 @@ def is_alignable(obj): targets = [] no_key = object() not_replaced = object() - for n, variables in enumerate(objects): + for position, variables in enumerate(objects): if is_alignable(variables): - positions.append(n) + positions.append(position) keys.append(no_key) targets.append(variables) out.append(not_replaced) elif is_dict_like(variables): + current_out = OrderedDict() for k, v in variables.items(): - if is_alignable(v) and k not in indexes: - # Skip variables in indexes for alignment, because these - # should to be overwritten instead: - # https://github.com/pydata/xarray/issues/725 - positions.append(n) + if is_alignable(v): + positions.append(position) keys.append(k) targets.append(v) - out.append(OrderedDict(variables)) + current_out[k] = not_replaced + else: + current_out[k] = v + out.append(current_out) elif raise_on_invalid: raise ValueError( "object to align is neither an xarray.Dataset, " - "an xarray.DataArray nor a dictionary: %r" % variables + "an xarray.DataArray nor a dictionary: {!r}".format(variables) ) else: out.append(variables) @@ -276,7 +277,10 @@ def is_alignable(obj): out[position][key] = aligned_obj # something went wrong: we should have replaced all sentinel values - assert all(arg is not not_replaced for arg in out) + for arg in out: + assert arg is not not_replaced + if is_dict_like(arg): + assert all(value is not not_replaced for value in arg.values()) return out diff --git a/xarray/core/computation.py b/xarray/core/computation.py index cb3a0d5db7d..bf22ca625ef 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -24,12 +24,13 @@ from . import duck_array_ops, utils from .alignment import deep_align -from .merge import expand_and_merge_variables +from .merge import merge_coordinates_without_align from .pycompat import dask_array_type from .utils import is_dict_like from .variable import Variable if TYPE_CHECKING: + from .coordinates import Coordinates from .dataset import Dataset _DEFAULT_FROZEN_SET = frozenset() # type: frozenset @@ -144,17 +145,16 @@ def result_name(objects: list) -> Any: return name -def _get_coord_variables(args): - input_coords = [] +def _get_coords_list(args) -> List["Coordinates"]: + coords_list = [] for arg in args: try: coords = arg.coords except AttributeError: pass # skip this argument else: - coord_vars = getattr(coords, "variables", coords) - input_coords.append(coord_vars) - return input_coords + coords_list.append(coords) + return coords_list def build_output_coords( @@ -177,32 +177,29 @@ def build_output_coords( ------- OrderedDict of Variable objects with merged coordinates. """ - input_coords = _get_coord_variables(args) + coords_list = _get_coords_list(args) - if exclude_dims: - input_coords = [ - OrderedDict( - (k, v) for k, v in coord_vars.items() if exclude_dims.isdisjoint(v.dims) - ) - for coord_vars in input_coords - ] - - if len(input_coords) == 1: + if len(coords_list) == 1 and not exclude_dims: # we can skip the expensive merge - unpacked_input_coords, = input_coords - merged = OrderedDict(unpacked_input_coords) + unpacked_coords, = coords_list + merged_vars = OrderedDict(unpacked_coords.variables) else: - merged = expand_and_merge_variables(input_coords) + # TODO: save these merged indexes, instead of re-computing them later + merged_vars, unused_indexes = merge_coordinates_without_align( + coords_list, exclude_dims=exclude_dims + ) output_coords = [] for output_dims in signature.output_core_dims: dropped_dims = signature.all_input_core_dims - set(output_dims) if dropped_dims: filtered = OrderedDict( - (k, v) for k, v in merged.items() if dropped_dims.isdisjoint(v.dims) + (k, v) + for k, v in merged_vars.items() + if dropped_dims.isdisjoint(v.dims) ) else: - filtered = merged + filtered = merged_vars output_coords.append(filtered) return output_coords diff --git a/xarray/core/coordinates.py b/xarray/core/coordinates.py index 562d30dd6c7..c4027792a7e 100644 --- a/xarray/core/coordinates.py +++ b/xarray/core/coordinates.py @@ -4,12 +4,13 @@ TYPE_CHECKING, Any, Hashable, - Mapping, Iterator, + Mapping, + Optional, Union, Set, - Tuple, Sequence, + Tuple, cast, ) @@ -17,11 +18,7 @@ from . import formatting, indexing from .indexes import Indexes -from .merge import ( - expand_and_merge_variables, - merge_coords, - merge_coords_for_inplace_math, -) +from .merge import merge_coords, merge_coordinates_without_align from .utils import Frozen, ReprObject, either_dict_or_kwargs from .variable import Variable @@ -34,7 +31,7 @@ _THIS_ARRAY = ReprObject("") -class AbstractCoordinates(Mapping[Hashable, "DataArray"]): +class Coordinates(Mapping[Hashable, "DataArray"]): _data = None # type: Union["DataArray", "Dataset"] def __getitem__(self, key: Hashable) -> "DataArray": @@ -57,10 +54,10 @@ def indexes(self) -> Indexes: @property def variables(self): - raise NotImplementedError() + raise NotImplementedError - def _update_coords(self, coords): - raise NotImplementedError() + def _update_coords(self, coords, indexes): + raise NotImplementedError def __iter__(self) -> Iterator["Hashable"]: # needs to be in the same order as the dataset variables @@ -116,19 +113,19 @@ def to_index(self, ordered_dims: Sequence[Hashable] = None) -> pd.Index: def update(self, other: Mapping[Hashable, Any]) -> None: other_vars = getattr(other, "variables", other) - coords = merge_coords( + coords, indexes = merge_coords( [self.variables, other_vars], priority_arg=1, indexes=self.indexes ) - self._update_coords(coords) + self._update_coords(coords, indexes) def _merge_raw(self, other): """For use with binary arithmetic.""" if other is None: variables = OrderedDict(self.variables) + indexes = OrderedDict(self.indexes) else: - # don't align because we already called xarray.align - variables = expand_and_merge_variables([self.variables, other.variables]) - return variables + variables, indexes = merge_coordinates_without_align([self, other]) + return variables, indexes @contextmanager def _merge_inplace(self, other): @@ -136,18 +133,18 @@ def _merge_inplace(self, other): if other is None: yield else: - # don't include indexes in priority_vars, because we didn't align - # first - priority_vars = OrderedDict( - kv for kv in self.variables.items() if kv[0] not in self.dims - ) - variables = merge_coords_for_inplace_math( - [self.variables, other.variables], priority_vars=priority_vars + # don't include indexes in prioritized, because we didn't align + # first and we want indexes to be checked + prioritized = { + k: (v, None) for k, v in self.variables.items() if k not in self.indexes + } + variables, indexes = merge_coordinates_without_align( + [self, other], prioritized ) yield - self._update_coords(variables) + self._update_coords(variables, indexes) - def merge(self, other: "AbstractCoordinates") -> "Dataset": + def merge(self, other: "Coordinates") -> "Dataset": """Merge two sets of coordinates to create a new Dataset The method implements the logic used for joining coordinates in the @@ -173,13 +170,19 @@ def merge(self, other: "AbstractCoordinates") -> "Dataset": if other is None: return self.to_dataset() - else: - other_vars = getattr(other, "variables", other) - coords = expand_and_merge_variables([self.variables, other_vars]) - return Dataset._from_vars_and_coord_names(coords, set(coords)) + if not isinstance(other, Coordinates): + other = Dataset(coords=other).coords + + coords, indexes = merge_coordinates_without_align([self, other]) + coord_names = set(coords) + merged = Dataset._construct_direct( + variables=coords, coord_names=coord_names, indexes=indexes + ) + return merged -class DatasetCoordinates(AbstractCoordinates): + +class DatasetCoordinates(Coordinates): """Dictionary like container for Dataset coordinates. Essentially an immutable OrderedDict with keys given by the array's @@ -218,7 +221,11 @@ def to_dataset(self) -> "Dataset": """ return self._data._copy_listed(self._names) - def _update_coords(self, coords: Mapping[Hashable, Any]) -> None: + def _update_coords( + self, + coords: "OrderedDict[Hashable, Variable]", + indexes: Mapping[Hashable, pd.Index], + ) -> None: from .dataset import calculate_dimensions variables = self._data._variables.copy() @@ -234,7 +241,12 @@ def _update_coords(self, coords: Mapping[Hashable, Any]) -> None: self._data._variables = variables self._data._coord_names.update(new_coord_names) self._data._dims = dims - self._data._indexes = None + + # TODO(shoyer): once ._indexes is always populated by a dict, modify + # it to update inplace instead. + original_indexes = OrderedDict(self._data.indexes) + original_indexes.update(indexes) + self._data._indexes = original_indexes def __delitem__(self, key: Hashable) -> None: if key in self: @@ -251,7 +263,7 @@ def _ipython_key_completions_(self): ] -class DataArrayCoordinates(AbstractCoordinates): +class DataArrayCoordinates(Coordinates): """Dictionary like container for DataArray coordinates. Essentially an OrderedDict with keys given by the array's @@ -274,7 +286,11 @@ def _names(self) -> Set[Hashable]: def __getitem__(self, key: Hashable) -> "DataArray": return self._data._getitem_coord(key) - def _update_coords(self, coords) -> None: + def _update_coords( + self, + coords: "OrderedDict[Hashable, Variable]", + indexes: Mapping[Hashable, pd.Index], + ) -> None: from .dataset import calculate_dimensions coords_plus_data = coords.copy() @@ -285,7 +301,12 @@ def _update_coords(self, coords) -> None: "cannot add coordinates with new dimensions to " "a DataArray" ) self._data._coords = coords - self._data._indexes = None + + # TODO(shoyer): once ._indexes is always populated by a dict, modify + # it to update inplace instead. + original_indexes = OrderedDict(self._data.indexes) + original_indexes.update(indexes) + self._data._indexes = original_indexes @property def variables(self): diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index b311a847790..1c46b50de9c 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -2453,7 +2453,7 @@ def func(self, other): if not reflexive else f(other_variable, self.variable) ) - coords = self.coords._merge_raw(other_coords) + coords, indexes = self.coords._merge_raw(other_coords) name = self._result_name(other) return self._replace(variable, coords, name) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index d25f24c1e28..b71851f4390 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -65,7 +65,7 @@ dataset_merge_method, dataset_update_method, merge_data_and_coords, - merge_variables, + merge_coordinates_without_align, ) from .options import OPTIONS, _get_keep_attrs from .pycompat import dask_array_type @@ -483,10 +483,9 @@ def __init__( data_vars = {} if coords is None: coords = {} - self._set_init_vars_and_dims(data_vars, coords, compat) # TODO(shoyer): expose indexes as a public argument in __init__ - self._indexes = None # type: Optional[OrderedDict[Any, pd.Index]] + self._set_init_vars_and_dims(data_vars, coords, compat) if attrs is not None: self._attrs = OrderedDict(attrs) @@ -507,13 +506,14 @@ def _set_init_vars_and_dims(self, data_vars, coords, compat): if isinstance(coords, Dataset): coords = coords.variables - variables, coord_names, dims = merge_data_and_coords( + variables, coord_names, dims, indexes = merge_data_and_coords( data_vars, coords, compat=compat ) self._variables = variables self._coord_names = coord_names self._dims = dims + self._indexes = indexes @classmethod def load_store(cls, store, decoder=None) -> "Dataset": @@ -814,7 +814,7 @@ def _construct_direct( cls, variables, coord_names, - dims, + dims=None, attrs=None, indexes=None, encoding=None, @@ -823,6 +823,8 @@ def _construct_direct( """Shortcut around __init__ for internal use when we want to skip costly validation """ + if dims is None: + dims = calculate_dimensions(variables) obj = object.__new__(cls) obj._variables = variables obj._coord_names = coord_names @@ -838,8 +840,7 @@ def _construct_direct( @classmethod def _from_vars_and_coord_names(cls, variables, coord_names, attrs=None): - dims = calculate_dimensions(variables) - return cls._construct_direct(variables, coord_names, dims, attrs) + return cls._construct_direct(variables, coord_names, attrs=attrs) # TODO(shoyer): renable type checking on this signature when pytype has a # good way to handle defaulting arguments to a sentinel value: @@ -1243,6 +1244,8 @@ def __delitem__(self, key: Hashable) -> None: """ del self._variables[key] self._coord_names.discard(key) + if key in self.indexes: + del self._indexes[key] self._dims = calculate_dimensions(self._variables) # mutable objects should not be hashable @@ -1785,20 +1788,16 @@ def _validate_indexers( return indexers_list def _get_indexers_coords_and_indexes(self, indexers): - """ Extract coordinates from indexers. - Returns an OrderedDict mapping from coordinate name to the - coordinate variable. + """Extract coordinates and indexes from indexers. Only coordinate with a name different from any of self.variables will be attached. """ from .dataarray import DataArray - coord_list = [] - indexes = OrderedDict() + coords_list = [] for k, v in indexers.items(): if isinstance(v, DataArray): - v_coords = v.coords if v.dtype.kind == "b": if v.ndim != 1: # we only support 1-d boolean array raise ValueError( @@ -1809,14 +1808,14 @@ def _get_indexers_coords_and_indexes(self, indexers): # Make sure in case of boolean DataArray, its # coordinate also should be indexed. v_coords = v[v.values.nonzero()[0]].coords - - coord_list.append({d: v_coords[d].variable for d in v.coords}) - indexes.update(v.indexes) + else: + v_coords = v.coords + coords_list.append(v_coords) # we don't need to call align() explicitly or check indexes for # alignment, because merge_variables already checks for exact alignment # between dimension coordinates - coords = merge_variables(coord_list) + coords, indexes = merge_coordinates_without_align(coords_list) assert_coordinate_consistent(self, coords) # silently drop the conflicted variables. @@ -2552,12 +2551,14 @@ def _rename_vars(self, name_dict, dims_dict): def _rename_dims(self, name_dict): return {name_dict.get(k, k): v for k, v in self.dims.items()} - def _rename_indexes(self, name_dict): + def _rename_indexes(self, name_dict, dims_set): if self._indexes is None: return None indexes = OrderedDict() for k, v in self.indexes.items(): new_name = name_dict.get(k, k) + if new_name not in dims_set: + continue if isinstance(v, pd.MultiIndex): new_names = [name_dict.get(k, k) for k in v.names] index = pd.MultiIndex( @@ -2575,7 +2576,7 @@ def _rename_indexes(self, name_dict): def _rename_all(self, name_dict, dims_dict): variables, coord_names = self._rename_vars(name_dict, dims_dict) dims = self._rename_dims(dims_dict) - indexes = self._rename_indexes(name_dict) + indexes = self._rename_indexes(name_dict, dims.keys()) return variables, coord_names, dims, indexes def rename( @@ -3359,11 +3360,8 @@ def update(self, other: "DatasetLike", inplace: bool = None) -> "Dataset": dataset. """ inplace = _check_inplace(inplace, default=True) - variables, coord_names, dims = dataset_update_method(self, other) - - return self._replace_vars_and_dims( - variables, coord_names, dims, inplace=inplace - ) + merge_result = dataset_update_method(self, other) + return self._replace(inplace=inplace, **merge_result._asdict()) def merge( self, @@ -3425,7 +3423,7 @@ def merge( If any variables conflict (see ``compat``). """ inplace = _check_inplace(inplace) - variables, coord_names, dims = dataset_merge_method( + merge_result = dataset_merge_method( self, other, overwrite_vars=overwrite_vars, @@ -3433,10 +3431,7 @@ def merge( join=join, fill_value=fill_value, ) - - return self._replace_vars_and_dims( - variables, coord_names, dims, inplace=inplace - ) + return self._replace(inplace=inplace, **merge_result._asdict()) def _assert_all_in_dataset( self, names: Iterable[Hashable], virtual_okay: bool = False diff --git a/xarray/core/merge.py b/xarray/core/merge.py index 882667dbaaa..f73aed86724 100644 --- a/xarray/core/merge.py +++ b/xarray/core/merge.py @@ -1,5 +1,7 @@ from collections import OrderedDict +import functools from typing import ( + AbstractSet, Any, Dict, Hashable, @@ -7,9 +9,12 @@ List, Mapping, MutableMapping, + NamedTuple, + Optional, Sequence, Set, Tuple, + TypeVar, Union, TYPE_CHECKING, ) @@ -18,10 +23,11 @@ from . import dtypes, pdcompat from .alignment import deep_align -from .utils import Frozen +from .utils import Frozen, dict_equiv from .variable import Variable, as_variable, assert_unique_multiindex_level_names if TYPE_CHECKING: + from .coordinates import Coordinates from .dataarray import DataArray from .dataset import Dataset @@ -29,9 +35,8 @@ DataArray, Variable, Tuple[Hashable, Any], Tuple[Sequence[Hashable], Any] ] DatasetLike = Union[Dataset, Mapping[Hashable, DatasetLikeValue]] - """Any object type that can be used on the rhs of Dataset.update, - Dataset.merge, etc. - """ + # Any object type that can be used on the rhs of Dataset.update, + # Dataset.merge, etc. MutableDatasetLike = Union[Dataset, MutableMapping[Hashable, DatasetLikeValue]] @@ -70,8 +75,9 @@ class MergeError(ValueError): # TODO: move this to an xarray.exceptions module? -def unique_variable(name, variables, compat="broadcast_equals"): - # type: (Any, List[Variable], str) -> Variable +def unique_variable( + name: Hashable, variables: List[Variable], compat: str = "broadcast_equals" +) -> Variable: """Return the unique variable from a list of variables or raise MergeError. Parameters @@ -126,138 +132,193 @@ def _assert_compat_valid(compat): raise ValueError("compat=%r invalid: must be %s" % (compat, set(_VALID_COMPAT))) -class OrderedDefaultDict(OrderedDict): - # minimal version of an ordered defaultdict - # beware: does not pickle or copy properly - def __init__(self, default_factory): - self.default_factory = default_factory - super().__init__() - - def __missing__(self, key): - self[key] = default = self.default_factory() - return default +MergeElement = Tuple[Variable, Optional[pd.Index]] -def merge_variables( - list_of_variables_dicts: List[Mapping[Any, Variable]], - priority_vars: Mapping[Any, Variable] = None, +def merge_collected( + grouped: "OrderedDict[Hashable, List[MergeElement]]", + prioritized: Mapping[Hashable, MergeElement] = None, compat: str = "minimal", -) -> "OrderedDict[Any, Variable]": +) -> Tuple["OrderedDict[Hashable, Variable]", "OrderedDict[Hashable, pd.Index]"]: """Merge dicts of variables, while resolving conflicts appropriately. Parameters ---------- - lists_of_variables_dicts : list of mappings with Variable values - List of mappings for which each value is a xarray.Variable object. - priority_vars : mapping with Variable or None values, optional - If provided, variables are always taken from this dict in preference to - the input variable dictionaries, without checking for conflicts. + grouped : mapping + Lists of variables of the same name to merge. + prioritized : mapping, optional + If provided, variables/indexes are always taken from this dict in + preference to the grouped dictionaries, without checking for conflicts. compat : {'identical', 'equals', 'broadcast_equals', 'minimal', 'no_conflicts'}, optional Type of equality check to use when checking for conflicts. Returns ------- - OrderedDict with keys taken by the union of keys on list_of_variable_dicts, + OrderedDict with keys taken by the union of keys on list_of_mappings, and Variable values corresponding to those that should be found on the merged result. """ # noqa - if priority_vars is None: - priority_vars = {} + if prioritized is None: + prioritized = {} _assert_compat_valid(compat) - dim_compat = min(compat, "equals", key=_VALID_COMPAT.get) - - lookup = OrderedDefaultDict(list) - for variables in list_of_variables_dicts: - for name, var in variables.items(): - lookup[name].append(var) - - # n.b. it's important to fill up merged in the original order in which - # variables appear - merged = OrderedDict() # type: OrderedDict[Any, Variable] - - for name, var_list in lookup.items(): - if name in priority_vars: - # one of these arguments (e.g., the first for in-place arithmetic - # or the second for Dataset.update) takes priority - merged[name] = priority_vars[name] + + merged_vars = OrderedDict() # type: OrderedDict[Any, Variable] + merged_indexes = OrderedDict() # type: OrderedDict[Any, pd.Index] + + for name, elements_list in grouped.items(): + if name in prioritized: + variable, index = prioritized[name] + merged_vars[name] = variable + if index is not None: + merged_indexes[name] = index else: - dim_variables = [var for var in var_list if (name,) == var.dims] - if dim_variables: - # if there are dimension coordinates, these must be equal (or - # identical), and they take priority over non-dimension - # coordinates - merged[name] = unique_variable(name, dim_variables, dim_compat) + indexed_elements = [ + (variable, index) + for variable, index in elements_list + if index is not None + ] + + if indexed_elements: + # TODO(shoyer): consider adjusting this logic. Are we really + # OK throwing away variable without an index in favor of + # indexed variables, without even checking if values match? + variable, index = indexed_elements[0] + for _, other_index in indexed_elements[1:]: + if not index.equals(other_index): + raise MergeError( + "conflicting values for index %r on objects to be " + "combined:\nfirst value: %r\nsecond value: %r" + % (name, index, other_index) + ) + if compat == "identical": + for other_variable, _ in indexed_elements[1:]: + if not dict_equiv(variable.attrs, other_variable.attrs): + raise MergeError( + "conflicting attribute values on combined " + "variable %r:\nfirst value: %r\nsecond value: %r" + % (name, variable.attrs, other_variable.attrs) + ) + merged_vars[name] = variable + merged_indexes[name] = index else: + variables = [variable for variable, _ in elements_list] try: - merged[name] = unique_variable(name, var_list, compat) + merged_vars[name] = unique_variable(name, variables, compat) except MergeError: if compat != "minimal": # we need more than "minimal" compatibility (for which # we drop conflicting coordinates) raise - return merged - + return merged_vars, merged_indexes -def expand_variable_dicts( - list_of_variable_dicts: "List[Union[Dataset, OrderedDict]]", -) -> "List[Mapping[Any, Variable]]": - """Given a list of dicts with xarray object values, expand the values. - Parameters - ---------- - list_of_variable_dicts : list of dict or Dataset objects - Each value for the mappings must be of the following types: - - an xarray.Variable - - a tuple `(dims, data[, attrs[, encoding]])` that can be converted in - an xarray.Variable - - or an xarray.DataArray +def collect_variables_and_indexes( + list_of_mappings: "List[Union[Dataset, OrderedDict]]", +) -> "OrderedDict[Hashable, List[MergeElement]]": + """Collect variables and indexes from list of mappings of xarray objects. - Returns - ------- - A list of ordered dictionaries corresponding to inputs, or coordinates from - an input's values. The values of each ordered dictionary are all - xarray.Variable objects. + Mappings must either be Dataset objects, or have values of one of the + following types: + - an xarray.Variable + - a tuple `(dims, data[, attrs[, encoding]])` that can be converted in + an xarray.Variable + - or an xarray.DataArray """ from .dataarray import DataArray # noqa: F811 from .dataset import Dataset - var_dicts = [] + grouped = ( + OrderedDict() + ) # type: OrderedDict[Hashable, List[Tuple[Variable, pd.Index]]] - for variables in list_of_variable_dicts: - if isinstance(variables, Dataset): - var_dicts.append(variables.variables) - continue + def append(name, variable, index): + values = grouped.setdefault(name, []) + values.append((variable, index)) - # append coords to var_dicts before appending sanitized_vars, - # because we want coords to appear first - sanitized_vars = OrderedDict() # type: OrderedDict[Any, Variable] + def append_all(variables, indexes): + for name, variable in variables.items(): + append(name, variable, indexes.get(name)) - for name, var in variables.items(): - if isinstance(var, DataArray): - # use private API for speed - coords = var._coords.copy() + for mapping in list_of_mappings: + if isinstance(mapping, Dataset): + append_all(mapping.variables, mapping.indexes) + continue + + for name, variable in mapping.items(): + if isinstance(variable, DataArray): + coords = variable._coords.copy() # use private API for speed + indexes = OrderedDict(variable.indexes) # explicitly overwritten variables should take precedence coords.pop(name, None) - var_dicts.append(coords) - - var = as_variable(var, name=name) - sanitized_vars[name] = var + indexes.pop(name, None) + append_all(coords, indexes) - var_dicts.append(sanitized_vars) + variable = as_variable(variable, name=name) + if variable.dims == (name,): + variable = variable.to_index_variable() + index = variable.to_index() + else: + index = None + append(name, variable, index) + + return grouped + + +def collect_from_coordinates( + list_of_coords: "List[Coordinates]" +) -> "OrderedDict[Hashable, List[MergeElement]]": + """Collect variables and indexes to be merged from Coordinate objects.""" + grouped = ( + OrderedDict() + ) # type: OrderedDict[Hashable, List[Tuple[Variable, pd.Index]]] + + for coords in list_of_coords: + variables = coords.variables + indexes = coords.indexes + for name, variable in variables.items(): + value = grouped.setdefault(name, []) + value.append((variable, indexes.get(name))) + return grouped + + +def merge_coordinates_without_align( + objects: "List[Coordinates]", + prioritized: Mapping[Hashable, MergeElement] = None, + exclude_dims: AbstractSet = frozenset(), +) -> Tuple["OrderedDict[Hashable, Variable]", "OrderedDict[Hashable, pd.Index]"]: + """Merge variables/indexes from coordinates without automatic alignments. + + This function is used for merging coordinate from pre-existing xarray + objects. + """ + collected = collect_from_coordinates(objects) + + if exclude_dims: + filtered = OrderedDict() # type: OrderedDict[Hashable, List[MergeElement]] + for name, elements in collected.items(): + new_elements = [ + (variable, index) + for variable, index in elements + if exclude_dims.isdisjoint(variable.dims) + ] + if new_elements: + filtered[name] = new_elements + else: + filtered = collected - return var_dicts + return merge_collected(filtered, prioritized) def determine_coords( - list_of_variable_dicts: Iterable["DatasetLike"] + list_of_mappings: Iterable["DatasetLike"] ) -> Tuple[Set[Hashable], Set[Hashable]]: """Given a list of dicts with xarray object values, identify coordinates. Parameters ---------- - list_of_variable_dicts : list of dict or Dataset objects + list_of_mappings : list of dict or Dataset objects Of the same form as the arguments to expand_variable_dicts. Returns @@ -273,7 +334,7 @@ def determine_coords( coord_names = set() # type: set noncoord_names = set() # type: set - for variables in list_of_variable_dicts: + for variables in list_of_mappings: if isinstance(variables, Dataset): coord_names.update(variables.coords) noncoord_names.update(variables.data_vars) @@ -321,18 +382,7 @@ def coerce_pandas_values(objects: Iterable["DatasetLike"]) -> List["DatasetLike" return out -def merge_coords_for_inplace_math(objs, priority_vars=None): - """Merge coordinate variables without worrying about alignment. - - This function is used for merging variables in coordinates.py. - """ - expanded = expand_variable_dicts(objs) - variables = merge_variables(expanded, priority_vars) - assert_unique_multiindex_level_names(variables) - return variables - - -def _get_priority_vars(objects, priority_arg, compat="equals"): +def _get_priority_vars_and_indexes(objects, priority_arg, compat="equals"): """Extract the priority variable from a list of mappings. We need this method because in some cases the priority argument itself @@ -354,32 +404,24 @@ def _get_priority_vars(objects, priority_arg, compat="equals"): values indicating priority variables. """ # noqa if priority_arg is None: - priority_vars = {} - else: - expanded = expand_variable_dicts([objects[priority_arg]]) - priority_vars = merge_variables(expanded, compat=compat) - return priority_vars - + return OrderedDict() -def expand_and_merge_variables(objs, priority_arg=None): - """Merge coordinate variables without worrying about alignment. - - This function is used for merging variables in computation.py. - """ - expanded = expand_variable_dicts(objs) - priority_vars = _get_priority_vars(objs, priority_arg) - variables = merge_variables(expanded, priority_vars) - return variables + collected = collect_variables_and_indexes([objects[priority_arg]]) + variables, indexes = merge_collected(collected, compat=compat) + grouped = OrderedDict() + for name, variable in variables.items(): + grouped[name] = (variable, indexes.get(name)) + return grouped def merge_coords( - objs, + objects, compat="minimal", join="outer", priority_arg=None, indexes=None, fill_value=dtypes.NA, -): +) -> Tuple["OrderedDict[Hashable, Variable]", "OrderedDict[Hashable, pd.Index]"]: """Merge coordinate variables. See merge_core below for argument descriptions. This works similarly to @@ -387,29 +429,28 @@ def merge_coords( coordinates or not. """ _assert_compat_valid(compat) - coerced = coerce_pandas_values(objs) + coerced = coerce_pandas_values(objects) aligned = deep_align( coerced, join=join, copy=False, indexes=indexes, fill_value=fill_value ) - expanded = expand_variable_dicts(aligned) - priority_vars = _get_priority_vars(aligned, priority_arg, compat=compat) - variables = merge_variables(expanded, priority_vars, compat=compat) + collected = collect_variables_and_indexes(aligned) + prioritized = _get_priority_vars_and_indexes(aligned, priority_arg, compat=compat) + variables, out_indexes = merge_collected(collected, prioritized, compat=compat) assert_unique_multiindex_level_names(variables) - - return variables + return variables, out_indexes def merge_data_and_coords(data, coords, compat="broadcast_equals", join="outer"): """Used in Dataset.__init__.""" - objs = [data, coords] + objects = [data, coords] explicit_coords = coords.keys() - indexes = dict(extract_indexes(coords)) + indexes = dict(_extract_indexes_from_coords(coords)) return merge_core( - objs, compat, join, explicit_coords=explicit_coords, indexes=indexes + objects, compat, join, explicit_coords=explicit_coords, indexes=indexes ) -def extract_indexes(coords): +def _extract_indexes_from_coords(coords): """Yields the name & index of valid indexes from a mapping of coords""" for name, variable in coords.items(): variable = as_variable(variable, name=name) @@ -432,31 +473,42 @@ def assert_valid_explicit_coords(variables, dims, explicit_coords): ) +_MergeResult = NamedTuple( + "_MergeResult", + [ + ("variables", "OrderedDict[Hashable, Variable]"), + ("coord_names", Set[Hashable]), + ("dims", Dict[Hashable, int]), + ("indexes", "OrderedDict[Hashable, pd.Index]"), + ], +) + + def merge_core( - objs, + objects, compat="broadcast_equals", join="outer", priority_arg=None, explicit_coords=None, indexes=None, fill_value=dtypes.NA, -) -> Tuple["OrderedDict[Hashable, Variable]", Set[Hashable], Dict[Hashable, int]]: +) -> _MergeResult: """Core logic for merging labeled objects. This is not public API. Parameters ---------- - objs : list of mappings + objects : list of mappings All values must be convertable to labeled arrays. compat : {'identical', 'equals', 'broadcast_equals', 'no_conflicts'}, optional Compatibility checks to use when merging variables. join : {'outer', 'inner', 'left', 'right'}, optional How to combine objects with different indexes. priority_arg : integer, optional - Optional argument in `objs` that takes precedence over the others. + Optional argument in `objects` that takes precedence over the others. explicit_coords : set, optional - An explicit list of variables from `objs` that are coordinates. + An explicit list of variables from `objects` that are coordinates. indexes : dict, optional Dictionary with values given by pandas.Index objects. fill_value : scalar, optional @@ -479,28 +531,25 @@ def merge_core( _assert_compat_valid(compat) - coerced = coerce_pandas_values(objs) + coerced = coerce_pandas_values(objects) aligned = deep_align( coerced, join=join, copy=False, indexes=indexes, fill_value=fill_value ) - expanded = expand_variable_dicts(aligned) - - coord_names, noncoord_names = determine_coords(coerced) + collected = collect_variables_and_indexes(aligned) - priority_vars = _get_priority_vars(aligned, priority_arg, compat=compat) - variables = merge_variables(expanded, priority_vars, compat=compat) + prioritized = _get_priority_vars_and_indexes(aligned, priority_arg, compat=compat) + variables, out_indexes = merge_collected(collected, prioritized, compat=compat) assert_unique_multiindex_level_names(variables) dims = calculate_dimensions(variables) + coord_names, noncoord_names = determine_coords(coerced) if explicit_coords is not None: assert_valid_explicit_coords(variables, dims, explicit_coords) coord_names.update(explicit_coords) - for dim, size in dims.items(): if dim in variables: coord_names.add(dim) - ambiguous_coords = coord_names.intersection(noncoord_names) if ambiguous_coords: raise MergeError( @@ -508,7 +557,7 @@ def merge_core( "coordinates or not in the merged result: %s" % ambiguous_coords ) - return variables, coord_names, dims + return _MergeResult(variables, coord_names, dims, out_indexes) def merge(objects, compat="no_conflicts", join="outer", fill_value=dtypes.NA): @@ -580,7 +629,7 @@ def merge(objects, compat="no_conflicts", join="outer", fill_value=dtypes.NA): dict_like_objects = list() for obj in objects: - if not (isinstance(obj, (DataArray, Dataset, dict))): + if not isinstance(obj, (DataArray, Dataset, dict)): raise TypeError( "objects must be an iterable containing only " "Dataset(s), DataArray(s), and dictionaries." @@ -589,12 +638,8 @@ def merge(objects, compat="no_conflicts", join="outer", fill_value=dtypes.NA): obj = obj.to_dataset() if isinstance(obj, DataArray) else obj dict_like_objects.append(obj) - variables, coord_names, dims = merge_core( - dict_like_objects, compat, join, fill_value=fill_value - ) - # TODO: don't always recompute indexes - merged = Dataset._construct_direct(variables, coord_names, dims, indexes=None) - + merge_result = merge_core(dict_like_objects, compat, join, fill_value=fill_value) + merged = Dataset._construct_direct(**merge_result._asdict()) return merged @@ -605,10 +650,9 @@ def dataset_merge_method( compat: str, join: str, fill_value: Any, -) -> Tuple["OrderedDict[Hashable, Variable]", Set[Hashable], Dict[Hashable, int]]: +) -> _MergeResult: """Guts of the Dataset.merge method. """ - # we are locked into supporting overwrite_vars for the Dataset.merge # method due for backwards compatibility # TODO: consider deprecating it? @@ -640,9 +684,7 @@ def dataset_merge_method( ) -def dataset_update_method( - dataset: "Dataset", other: "DatasetLike" -) -> Tuple["OrderedDict[Hashable, Variable]", Set[Hashable], Dict[Hashable, int]]: +def dataset_update_method(dataset: "Dataset", other: "DatasetLike") -> _MergeResult: """Guts of the Dataset.update method. This drops a duplicated coordinates from `other` if `other` is not an From 7d8312515dc80a1f0ba02a4e9a23004f6bd13a2b Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Mon, 19 Aug 2019 23:29:46 -0700 Subject: [PATCH 2/5] Typing fixes --- xarray/core/merge.py | 38 +++++++++++++++++++++----------------- 1 file changed, 21 insertions(+), 17 deletions(-) diff --git a/xarray/core/merge.py b/xarray/core/merge.py index f73aed86724..a04ba686b4b 100644 --- a/xarray/core/merge.py +++ b/xarray/core/merge.py @@ -31,13 +31,14 @@ from .dataarray import DataArray from .dataset import Dataset - DatasetLikeValue = Union[ - DataArray, Variable, Tuple[Hashable, Any], Tuple[Sequence[Hashable], Any] + DimsLike = Union[Hashable, Sequence[Hashable]] + VariableTuple = Union[ + Tuple[DimsLike, Any], + Tuple[DimsLike, Any, Mapping], + Tuple[DimsLike, Any, Mapping, Mapping], ] + DatasetLikeValue = Union[DataArray, Variable, VariableTuple] DatasetLike = Union[Dataset, Mapping[Hashable, DatasetLikeValue]] - # Any object type that can be used on the rhs of Dataset.update, - # Dataset.merge, etc. - MutableDatasetLike = Union[Dataset, MutableMapping[Hashable, DatasetLikeValue]] PANDAS_TYPES = (pd.Series, pd.DataFrame, pdcompat.Panel) @@ -215,7 +216,7 @@ def merge_collected( def collect_variables_and_indexes( - list_of_mappings: "List[Union[Dataset, OrderedDict]]", + list_of_mappings: "List[DatasetLike]", ) -> "OrderedDict[Hashable, List[MergeElement]]": """Collect variables and indexes from list of mappings of xarray objects. @@ -334,12 +335,12 @@ def determine_coords( coord_names = set() # type: set noncoord_names = set() # type: set - for variables in list_of_mappings: - if isinstance(variables, Dataset): - coord_names.update(variables.coords) - noncoord_names.update(variables.data_vars) + for mapping in list_of_mappings: + if isinstance(mapping, Dataset): + coord_names.update(mapping.coords) + noncoord_names.update(mapping.data_vars) else: - for name, var in variables.items(): + for name, var in mapping.items(): if isinstance(var, DataArray): coords = set(var._coords) # use private API for speed # explicitly overwritten variables should take precedence @@ -382,7 +383,9 @@ def coerce_pandas_values(objects: Iterable["DatasetLike"]) -> List["DatasetLike" return out -def _get_priority_vars_and_indexes(objects, priority_arg, compat="equals"): +def _get_priority_vars_and_indexes( + objects: List["DatasetLike"], priority_arg: Optional[int], compat: str = "equals" +) -> "OrderedDict[Hashable, MergeElement]": """Extract the priority variable from a list of mappings. We need this method because in some cases the priority argument itself @@ -400,15 +403,14 @@ def _get_priority_vars_and_indexes(objects, priority_arg, compat="equals"): Returns ------- - None, if priority_arg is None, or an OrderedDict with Variable objects as - values indicating priority variables. + An OrderedDict of variables and associated indexes (if any) to prioritize. """ # noqa if priority_arg is None: return OrderedDict() collected = collect_variables_and_indexes([objects[priority_arg]]) variables, indexes = merge_collected(collected, compat=compat) - grouped = OrderedDict() + grouped = OrderedDict() # type: OrderedDict[Hashable, MergeElement] for name, variable in variables.items(): grouped[name] = (variable, indexes.get(name)) return grouped @@ -669,8 +671,10 @@ def dataset_merge_method( objs = [dataset, other] priority_arg = 1 else: - other_overwrite = OrderedDict() # type: MutableDatasetLike - other_no_overwrite = OrderedDict() # type: MutableDatasetLike + other_overwrite = OrderedDict() # type: OrderedDict[Hashable, DatasetLikeValue] + other_no_overwrite = ( + OrderedDict() + ) # type: OrderedDict[Hashable, DatasetLikeValue] for k, v in other.items(): if k in overwrite_vars: other_overwrite[k] = v From e36050c2a7428601a61ad566c3b189e5db7de0b5 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Mon, 19 Aug 2019 23:42:14 -0700 Subject: [PATCH 3/5] More tying fixes --- xarray/core/dataset.py | 6 ++--- xarray/core/merge.py | 51 +++++++++++++++++++++++++----------------- 2 files changed, 33 insertions(+), 24 deletions(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index b71851f4390..12ff8eeb88e 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -84,7 +84,7 @@ if TYPE_CHECKING: from ..backends import AbstractDataStore, ZarrStore from .dataarray import DataArray - from .merge import DatasetLike + from .merge import CoercibleMapping try: from dask.delayed import Delayed @@ -3330,7 +3330,7 @@ def unstack(self, dim: Union[Hashable, Iterable[Hashable]] = None) -> "Dataset": result = result._unstack_once(dim) return result - def update(self, other: "DatasetLike", inplace: bool = None) -> "Dataset": + def update(self, other: "CoercibleMapping", inplace: bool = None) -> "Dataset": """Update this dataset's variables with those from another dataset. Parameters @@ -3365,7 +3365,7 @@ def update(self, other: "DatasetLike", inplace: bool = None) -> "Dataset": def merge( self, - other: "DatasetLike", + other: "CoercibleMapping", inplace: bool = None, overwrite_vars: Union[Hashable, Iterable[Hashable]] = frozenset(), compat: str = "no_conflicts", diff --git a/xarray/core/merge.py b/xarray/core/merge.py index a04ba686b4b..598c5c2ce17 100644 --- a/xarray/core/merge.py +++ b/xarray/core/merge.py @@ -37,8 +37,10 @@ Tuple[DimsLike, Any, Mapping], Tuple[DimsLike, Any, Mapping, Mapping], ] - DatasetLikeValue = Union[DataArray, Variable, VariableTuple] - DatasetLike = Union[Dataset, Mapping[Hashable, DatasetLikeValue]] + XarrayValue = Union[DataArray, Variable, VariableTuple] + DatasetLike = Union[Dataset, Mapping[Hashable, XarrayValue]] + CoercibleValue = Union[XarrayValue, pd.Series, pd.DataFrame] + CoercibleMapping = Union[Dataset, Mapping[Hashable, CoercibleValue]] PANDAS_TYPES = (pd.Series, pd.DataFrame, pdcompat.Panel) @@ -350,7 +352,7 @@ def determine_coords( return coord_names, noncoord_names -def coerce_pandas_values(objects: Iterable["DatasetLike"]) -> List["DatasetLike"]: +def coerce_pandas_values(objects: Iterable["CoercibleMapping"]) -> List["DatasetLike"]: """Convert pandas values found in a list of labeled objects. Parameters @@ -417,12 +419,12 @@ def _get_priority_vars_and_indexes( def merge_coords( - objects, - compat="minimal", - join="outer", - priority_arg=None, - indexes=None, - fill_value=dtypes.NA, + objects: Iterable["CoercibleMapping"], + compat: str = "minimal", + join: str = "outer", + priority_arg: Optional[int] = None, + indexes: Optional[Mapping[Hashable, pd.Index]] = None, + fill_value: object = dtypes.NA, ) -> Tuple["OrderedDict[Hashable, Variable]", "OrderedDict[Hashable, pd.Index]"]: """Merge coordinate variables. @@ -487,13 +489,13 @@ def assert_valid_explicit_coords(variables, dims, explicit_coords): def merge_core( - objects, - compat="broadcast_equals", - join="outer", - priority_arg=None, - explicit_coords=None, - indexes=None, - fill_value=dtypes.NA, + objects: Iterable["CoercibleMapping"], + compat: str = "broadcast_equals", + join: str = "outer", + priority_arg: Optional[int] = None, + explicit_coords: Optional[Sequence] = None, + indexes: Optional[Mapping[Hashable, pd.Index]] = None, + fill_value: object = dtypes.NA, ) -> _MergeResult: """Core logic for merging labeled objects. @@ -562,7 +564,12 @@ def merge_core( return _MergeResult(variables, coord_names, dims, out_indexes) -def merge(objects, compat="no_conflicts", join="outer", fill_value=dtypes.NA): +def merge( + objects: Iterable[Union["DataArray", CoercibleMapping]], + compat: str = "no_conflicts", + join: str = "outer", + fill_value: object = dtypes.NA, +) -> "Dataset": """Merge any number of xarray objects into a single Dataset as variables. Parameters @@ -647,7 +654,7 @@ def merge(objects, compat="no_conflicts", join="outer", fill_value=dtypes.NA): def dataset_merge_method( dataset: "Dataset", - other: "DatasetLike", + other: "CoercibleMapping", overwrite_vars: Union[Hashable, Iterable[Hashable]], compat: str, join: str, @@ -671,10 +678,10 @@ def dataset_merge_method( objs = [dataset, other] priority_arg = 1 else: - other_overwrite = OrderedDict() # type: OrderedDict[Hashable, DatasetLikeValue] + other_overwrite = OrderedDict() # type: OrderedDict[Hashable, CoercibleValue] other_no_overwrite = ( OrderedDict() - ) # type: OrderedDict[Hashable, DatasetLikeValue] + ) # type: OrderedDict[Hashable, CoercibleValue] for k, v in other.items(): if k in overwrite_vars: other_overwrite[k] = v @@ -688,7 +695,9 @@ def dataset_merge_method( ) -def dataset_update_method(dataset: "Dataset", other: "DatasetLike") -> _MergeResult: +def dataset_update_method( + dataset: "Dataset", other: "CoercibleMapping" +) -> _MergeResult: """Guts of the Dataset.update method. This drops a duplicated coordinates from `other` if `other` is not an From 22804074fbe1757699c6d0c090e8e571686c32ef Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Tue, 20 Aug 2019 09:52:43 -0700 Subject: [PATCH 4/5] more typing fixes --- xarray/core/merge.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/xarray/core/merge.py b/xarray/core/merge.py index 598c5c2ce17..3251a07f6ba 100644 --- a/xarray/core/merge.py +++ b/xarray/core/merge.py @@ -32,12 +32,14 @@ from .dataset import Dataset DimsLike = Union[Hashable, Sequence[Hashable]] - VariableTuple = Union[ - Tuple[DimsLike, Any], - Tuple[DimsLike, Any, Mapping], - Tuple[DimsLike, Any, Mapping, Mapping], + ArrayLike = Any + VariableLike = Union[ + ArrayLike, + Tuple[DimsLike, ArrayLike], + Tuple[DimsLike, ArrayLike, Mapping], + Tuple[DimsLike, ArrayLike, Mapping, Mapping], ] - XarrayValue = Union[DataArray, Variable, VariableTuple] + XarrayValue = Union[DataArray, Variable, VariableLike] DatasetLike = Union[Dataset, Mapping[Hashable, XarrayValue]] CoercibleValue = Union[XarrayValue, pd.Series, pd.DataFrame] CoercibleMapping = Union[Dataset, Mapping[Hashable, CoercibleValue]] @@ -565,7 +567,7 @@ def merge_core( def merge( - objects: Iterable[Union["DataArray", CoercibleMapping]], + objects: Iterable[Union["DataArray", "CoercibleMapping"]], compat: str = "no_conflicts", join: str = "outer", fill_value: object = dtypes.NA, From 3cc9b5c68be6a2e0d7e6ae241cdd8ff84dc8dc9d Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Thu, 19 Sep 2019 09:11:54 -0700 Subject: [PATCH 5/5] fixup --- xarray/core/coordinates.py | 2 +- xarray/core/merge.py | 9 ++++++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/xarray/core/coordinates.py b/xarray/core/coordinates.py index 0bc73e10785..430e507396b 100644 --- a/xarray/core/coordinates.py +++ b/xarray/core/coordinates.py @@ -30,7 +30,7 @@ _THIS_ARRAY = ReprObject("") -class AbstractCoordinates(Mapping[Hashable, "DataArray"]): +class Coordinates(Mapping[Hashable, "DataArray"]): __slots__ = () def __getitem__(self, key: Hashable) -> "DataArray": diff --git a/xarray/core/merge.py b/xarray/core/merge.py index 7e718027014..11ef604a366 100644 --- a/xarray/core/merge.py +++ b/xarray/core/merge.py @@ -79,7 +79,10 @@ class MergeError(ValueError): def unique_variable( - name: Hashable, variables: List[Variable], compat: str = "broadcast_equals", equals: bool = None, + name: Hashable, + variables: List[Variable], + compat: str = "broadcast_equals", + equals: bool = None, ) -> Variable: """Return the unique variable from a list of variables or raise MergeError. @@ -129,8 +132,8 @@ def unique_variable( if not equals: raise MergeError( - "conflicting values for variable %r on objects to be combined. You can skip this check by specifying compat='override'." - % (name) + "conflicting values for variable {!r} on objects to be combined. " + "You can skip this check by specifying compat='override'.".format(name) ) if combine_method: