From 670d970b1e15d4972c55ca91e3f1340d1d46eafc Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 7 May 2023 16:53:35 +0200 Subject: [PATCH 01/50] 1. var_idx very slow --- xarray/core/concat.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/xarray/core/concat.py b/xarray/core/concat.py index dcf2a23d311..bfffe0ed625 100644 --- a/xarray/core/concat.py +++ b/xarray/core/concat.py @@ -3,6 +3,7 @@ from collections.abc import Hashable, Iterable from typing import TYPE_CHECKING, Any, Union, cast, overload +import numpy as np import pandas as pd from xarray.core import dtypes, utils @@ -575,18 +576,23 @@ def get_indexes(name): for name in vars_order: if name in concat_over and name not in result_indexes: variables = [] - variable_index = [] + # variable_index = [] + variable_index = np.array([], dtype=np.intp) var_concat_dim_length = [] for i, ds in enumerate(datasets): if name in ds.variables: variables.append(ds[name].variable) # add to variable index, needed for reindexing - var_idx = [ - sum(concat_dim_lengths[:i]) + k - for k in range(concat_dim_lengths[i]) - ] - variable_index.extend(var_idx) - var_concat_dim_length.append(len(var_idx)) + dim_len = concat_dim_lengths[i] + var_idx = sum(concat_dim_lengths[:i]) + np.arange(dim_len) + variable_index = np.append(variable_index, var_idx) + var_concat_dim_length.append(dim_len) + # var_idx = [ + # sum(concat_dim_lengths[:i]) + k + # for k in range(concat_dim_lengths[i]) + # ] + # variable_index.extend(var_idx) + # var_concat_dim_length.append(len(var_idx)) else: # raise if coordinate not in all datasets if name in coord_names: From 1370a0edd00a93e41760ee884ee49f235e29c373 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 7 May 2023 17:21:17 +0200 Subject: [PATCH 02/50] 2. slow any --- xarray/core/duck_array_ops.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index 84e66803fe8..fed5c64d83f 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -192,7 +192,8 @@ def asarray(data, xp=np): def as_shared_dtype(scalars_or_arrays, xp=np): """Cast a arrays to a shared dtype using xarray's type promotion rules.""" - if any(isinstance(x, array_type("cupy")) for x in scalars_or_arrays): + array_type_cupy = array_type("cupy") + if any(isinstance(x, array_type_cupy) for x in scalars_or_arrays): import cupy as cp arrays = [asarray(x, xp=cp) for x in scalars_or_arrays] From 8ab83faa3f8d5be0a1a7526fead5b629d14a233f Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 7 May 2023 18:03:47 +0200 Subject: [PATCH 03/50] Add test --- asv_bench/benchmarks/combine.py | 26 +++++++++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/asv_bench/benchmarks/combine.py b/asv_bench/benchmarks/combine.py index a4f8db2786b..8c05e3d1bff 100644 --- a/asv_bench/benchmarks/combine.py +++ b/asv_bench/benchmarks/combine.py @@ -3,7 +3,31 @@ import xarray as xr -class Combine: +class Combine1d: + """Benchmark concatenating and merging large datasets""" + + def setup(self): + """Create 2 datasets with two different variables""" + + t_size = 8000 + t = np.arange(t_size) + data = np.random.randn(t_size) + + self.dsA0 = xr.Dataset( + {"A": xr.DataArray(data, coords={"T": t}, dims=("T"))} + ) + self.dsA1 = xr.Dataset( + {"A": xr.DataArray(data, coords={"T": t + t_size}, dims=("T"))} + ) + + def time_combine_by_coords(self): + """Also has to load and arrange t coordinate""" + datasets = [self.dsA0, self.dsA1] + + xr.combine_by_coords(datasets) + + +class Combine3d: """Benchmark concatenating and merging large datasets""" def setup(self): From b6e188180cc2816d84d427b28551869968382237 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 7 May 2023 16:04:30 +0000 Subject: [PATCH 04/50] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- asv_bench/benchmarks/combine.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/asv_bench/benchmarks/combine.py b/asv_bench/benchmarks/combine.py index 8c05e3d1bff..2223996461d 100644 --- a/asv_bench/benchmarks/combine.py +++ b/asv_bench/benchmarks/combine.py @@ -13,9 +13,7 @@ def setup(self): t = np.arange(t_size) data = np.random.randn(t_size) - self.dsA0 = xr.Dataset( - {"A": xr.DataArray(data, coords={"T": t}, dims=("T"))} - ) + self.dsA0 = xr.Dataset({"A": xr.DataArray(data, coords={"T": t}, dims=("T"))}) self.dsA1 = xr.Dataset( {"A": xr.DataArray(data, coords={"T": t + t_size}, dims=("T"))} ) From 487e9585e549e3751e0b2b20a59647e5f5cb5c14 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 7 May 2023 18:15:46 +0200 Subject: [PATCH 05/50] 3. Slow array_type called multiple times --- xarray/core/pycompat.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/xarray/core/pycompat.py b/xarray/core/pycompat.py index 4a3f3638d14..d41c8f62af8 100644 --- a/xarray/core/pycompat.py +++ b/xarray/core/pycompat.py @@ -59,14 +59,26 @@ def __init__(self, mod: ModType) -> None: self.available = duck_array_module is not None +_cached_duck_array_modules: [dict[ModType, DuckArrayModule]] = {} + + +def _get_cached_duck_array_module(mod: ModType) -> DuckArrayModule: + if mod not in _cached_duck_array_modules: + duckmod = DuckArrayModule(mod) + _cached_duck_array_modules[mod] = duckmod + return duckmod + else: + return _cached_duck_array_modules[mod] + + def array_type(mod: ModType) -> DuckArrayTypes: """Quick wrapper to get the array class of the module.""" - return DuckArrayModule(mod).type + return _get_cached_duck_array_module(mod).type def mod_version(mod: ModType) -> Version: """Quick wrapper to get the version of the module.""" - return DuckArrayModule(mod).version + return _get_cached_duck_array_module(mod).version def is_dask_collection(x): From fbb5430c2bc963af366a472b27e5284e4a0e40bd Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 7 May 2023 18:41:29 +0200 Subject: [PATCH 06/50] 4. Can use fastpath for variable.concat? --- xarray/core/variable.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/variable.py b/xarray/core/variable.py index c19cb21cba2..cc24870dee4 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -2101,7 +2101,7 @@ def concat( f"Variable has dimensions {list(var.dims)} but first Variable has dimensions {list(first_var.dims)}" ) - return cls(dims, data, attrs, encoding) + return cls(dims, data, attrs, encoding, fastpath=True) def equals(self, other, equiv=duck_array_ops.array_equiv): """True if two Variables have the same dimensions and values; From dc5f0e6091eb7a55589e1e23edd41feea6fed193 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 7 May 2023 19:00:46 +0200 Subject: [PATCH 07/50] 5. slow init of pd.unique --- xarray/core/concat.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/xarray/core/concat.py b/xarray/core/concat.py index bfffe0ed625..76c59615445 100644 --- a/xarray/core/concat.py +++ b/xarray/core/concat.py @@ -544,7 +544,8 @@ def ensure_common_dims(vars, concat_dim_lengths): # ensure each variable with the given name shares the same # dimensions and the same shape for all of them except along the # concat dimension - common_dims = tuple(pd.unique([d for v in vars for d in v.dims])) + # common_dims = tuple(pd.unique([d for v in vars for d in v.dims])) + common_dims = tuple(dict.from_keys(d for v in vars for d in v.dims)) if dim not in common_dims: common_dims = (dim,) + common_dims for var, dim_len in zip(vars, concat_dim_lengths): From 38aa169f8dafd591286c4c83a7996689bab9c4df Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 7 May 2023 19:08:15 +0200 Subject: [PATCH 08/50] typos --- xarray/core/concat.py | 2 +- xarray/core/pycompat.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/xarray/core/concat.py b/xarray/core/concat.py index 76c59615445..866fb3434ec 100644 --- a/xarray/core/concat.py +++ b/xarray/core/concat.py @@ -545,7 +545,7 @@ def ensure_common_dims(vars, concat_dim_lengths): # dimensions and the same shape for all of them except along the # concat dimension # common_dims = tuple(pd.unique([d for v in vars for d in v.dims])) - common_dims = tuple(dict.from_keys(d for v in vars for d in v.dims)) + common_dims = tuple(dict.fromkeys(d for v in vars for d in v.dims)) if dim not in common_dims: common_dims = (dim,) + common_dims for var, dim_len in zip(vars, concat_dim_lengths): diff --git a/xarray/core/pycompat.py b/xarray/core/pycompat.py index d41c8f62af8..dbd91c0d2a5 100644 --- a/xarray/core/pycompat.py +++ b/xarray/core/pycompat.py @@ -59,7 +59,7 @@ def __init__(self, mod: ModType) -> None: self.available = duck_array_module is not None -_cached_duck_array_modules: [dict[ModType, DuckArrayModule]] = {} +_cached_duck_array_modules: dict[ModType, DuckArrayModule] = {} def _get_cached_duck_array_module(mod: ModType) -> DuckArrayModule: From 5fd6bcb8817cd4f12aa0187167380e04b9286ccf Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 7 May 2023 23:29:31 +0200 Subject: [PATCH 09/50] Update concat.py --- xarray/core/concat.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xarray/core/concat.py b/xarray/core/concat.py index 866fb3434ec..62fc82e31af 100644 --- a/xarray/core/concat.py +++ b/xarray/core/concat.py @@ -518,7 +518,7 @@ def _dataset_concat( if variables_to_merge: grouped = { k: v - for k, v in collect_variables_and_indexes(list(datasets)).items() + for k, v in collect_variables_and_indexes(datasets).items() if k in variables_to_merge } merged_vars, merged_indexes = merge_collected( @@ -570,7 +570,7 @@ def get_indexes(name): yield PandasIndex(data, dim, coord_dtype=var.dtype) # create concatenation index, needed for later reindexing - concat_index = list(range(sum(concat_dim_lengths))) + concat_index = np.arange(sum(concat_dim_lengths)) # stack up each variable and/or index to fill-out the dataset (in order) # n.b. this loop preserves variable order, needed for groupby. From 43dcff2cd7d6858df353d65518b539cbf5b6a951 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 7 May 2023 23:54:08 +0200 Subject: [PATCH 10/50] Update merge.py --- xarray/core/merge.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/merge.py b/xarray/core/merge.py index bf7288ad7ed..a3287268407 100644 --- a/xarray/core/merge.py +++ b/xarray/core/merge.py @@ -306,7 +306,7 @@ def merge_collected( def collect_variables_and_indexes( - list_of_mappings: list[DatasetLike], + list_of_mappings: Iterable[DatasetLike], indexes: Mapping[Any, Any] | None = None, ) -> dict[Hashable, list[MergeElement]]: """Collect variables and indexes from list of mappings of xarray objects. From 7ef0e5d6830cc51b2f9f7e142b6c2acb3af64a08 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Mon, 8 May 2023 23:13:09 +0200 Subject: [PATCH 11/50] 6. Avoid recalculating in loops --- xarray/core/missing.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/xarray/core/missing.py b/xarray/core/missing.py index d7f0be5fa08..4f976b46a6b 100644 --- a/xarray/core/missing.py +++ b/xarray/core/missing.py @@ -709,12 +709,7 @@ def interp_func(var, x, new_x, method: InterpOptions, kwargs): ] new_x_arginds = [item for pair in new_x_arginds for item in pair] - args = ( - var, - range(ndim), - *x_arginds, - *new_x_arginds, - ) + args = (var, range(ndim), *x_arginds, *new_x_arginds) _, rechunked = da.unify_chunks(*args) @@ -722,15 +717,16 @@ def interp_func(var, x, new_x, method: InterpOptions, kwargs): new_x = rechunked[1 + (len(rechunked) - 1) // 2 :] + new_x0_chunks = new_x[0].chunks + new_x0_shape = new_x[0].shape + new_x0_chunks_is_not_none = new_x0_chunks is not None new_axes = { - ndim + i: new_x[0].chunks[i] - if new_x[0].chunks is not None - else new_x[0].shape[i] + ndim + i: new_x0_chunks[i] if new_x0_chunks_is_not_none else new_x0_shape[i] for i in range(new_x[0].ndim) } # if useful, re-use localize for each chunk of new_x - localize = (method in ["linear", "nearest"]) and (new_x[0].chunks is not None) + localize = (method in ["linear", "nearest"]) and new_x0_chunks_is_not_none # scipy.interpolate.interp1d always forces to float. # Use the same check for blockwise as well: From b2c067da61440dbd2c90780c1017a7492e60ee88 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Mon, 8 May 2023 23:19:45 +0200 Subject: [PATCH 12/50] 7. No need to transpose 1d arrays. --- xarray/core/missing.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/xarray/core/missing.py b/xarray/core/missing.py index 4f976b46a6b..5d8309adda4 100644 --- a/xarray/core/missing.py +++ b/xarray/core/missing.py @@ -639,7 +639,7 @@ def interp(var, indexes_coords, method: InterpOptions, **kwargs): var.transpose(*original_dims).data, x, destination, method, kwargs ) - result = Variable(new_dims, interped, attrs=var.attrs) + result = Variable(new_dims, interped, attrs=var.attrs, fastpath=True) # dimension of the output array out_dims: OrderedSet = OrderedSet() @@ -648,7 +648,8 @@ def interp(var, indexes_coords, method: InterpOptions, **kwargs): out_dims.update(indexes_coords[d][1].dims) else: out_dims.add(d) - result = result.transpose(*out_dims) + if len(out_dims) > 1: + result = result.transpose(*out_dims) return result From ad048b63624b2f5d9501116ec4b53d5fb194e6f4 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Wed, 10 May 2023 18:42:33 +0200 Subject: [PATCH 13/50] 8. speed up dask_dataframe --- xarray/core/dataset.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 2336883d0b7..f6463a16310 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -6403,6 +6403,12 @@ def to_dask_dataframe( columns.extend(k for k in self.coords if k not in self.dims) columns.extend(self.data_vars) + has_many_dims = len(dim_order) > 1 + if has_many_dims: + ds_chunks = self.chunks + else: + ds_chunks = {} + series_list = [] for name in columns: try: @@ -6422,8 +6428,13 @@ def to_dask_dataframe( if not is_duck_dask_array(var._data): var = var.chunk() - dask_array = var.set_dims(ordered_dims).chunk(self.chunks).data - series = dd.from_array(dask_array.reshape(-1), columns=[name]) + if has_many_dims: + # Broadcast then flatten the array: + var_new_dims = var.set_dims(ordered_dims).chunk(ds_chunks) + dask_array = var_new_dims._data.reshape(-1) + else: + dask_array = var._data + series = dd.from_array(dask_array, columns=[name]) series_list.append(series) df = dd.concat(series_list, axis=1) From d6098831db51162e17c9678f17acd36afd8c7f08 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Wed, 10 May 2023 19:20:34 +0200 Subject: [PATCH 14/50] Update dataset.py --- xarray/core/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index f6463a16310..e8b3baa3e38 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -6403,7 +6403,7 @@ def to_dask_dataframe( columns.extend(k for k in self.coords if k not in self.dims) columns.extend(self.data_vars) - has_many_dims = len(dim_order) > 1 + has_many_dims = len(ordered_dims) > 1 if has_many_dims: ds_chunks = self.chunks else: From 4005f6f914c416738e44453299691619df776511 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Wed, 10 May 2023 21:02:16 +0200 Subject: [PATCH 15/50] Update dataset.py --- xarray/core/dataset.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index e8b3baa3e38..67e40ab8bd4 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -6410,6 +6410,7 @@ def to_dask_dataframe( ds_chunks = {} series_list = [] + df_meta = pd.Dataframe() for name in columns: try: var = self.variables[name] @@ -6434,7 +6435,7 @@ def to_dask_dataframe( dask_array = var_new_dims._data.reshape(-1) else: dask_array = var._data - series = dd.from_array(dask_array, columns=[name]) + series = dd.from_dask_array(dask_array, columns=name, meta=df_meta) series_list.append(series) df = dd.concat(series_list, axis=1) From d23c8334c143b4ac516bfc268d2f4b0c3cbee1b4 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Thu, 11 May 2023 05:39:29 +0200 Subject: [PATCH 16/50] Update dataset.py --- xarray/core/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 67e40ab8bd4..442a359f6f7 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -6410,7 +6410,7 @@ def to_dask_dataframe( ds_chunks = {} series_list = [] - df_meta = pd.Dataframe() + df_meta = pd.DataFrame() for name in columns: try: var = self.variables[name] From 6c6b5c76e8d1ec51187ddfc6a048b82929e5eccb Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Thu, 11 May 2023 06:10:03 +0200 Subject: [PATCH 17/50] Add dask combine test with many variables --- asv_bench/benchmarks/combine.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/asv_bench/benchmarks/combine.py b/asv_bench/benchmarks/combine.py index 2223996461d..bea2056dc58 100644 --- a/asv_bench/benchmarks/combine.py +++ b/asv_bench/benchmarks/combine.py @@ -2,6 +2,8 @@ import xarray as xr +from . import parameterized, randn, requires_dask + class Combine1d: """Benchmark concatenating and merging large datasets""" @@ -25,6 +27,24 @@ def time_combine_by_coords(self): xr.combine_by_coords(datasets) +class Combine1dDask: + """Benchmark concatenating and merging large datasets""" + + def setup(self): + """Create 2 datasets with two different variables""" + requires_dask() + + t_size = 8000 + var = xr.Variable(dims=("time",), data=np.random.randn(t_size)).chunk() + coord = xr.Variable(dims=("time",), data=np.random.randn(t_size)) + + data_vars = {f"long_name_{v}": ("time", var) for v in range(500)} + coords = {"time": ("time", coord)} + + self.dsA0 = xr.Dataset(data_vars, coords=coords) + self.dsA1 = xr.Dataset(data_vars) + + class Combine3d: """Benchmark concatenating and merging large datasets""" From 068ba556dc422c601214b17f3ee9a6ad6852bd1f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 11 May 2023 04:10:46 +0000 Subject: [PATCH 18/50] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- asv_bench/benchmarks/combine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/asv_bench/benchmarks/combine.py b/asv_bench/benchmarks/combine.py index bea2056dc58..2bc5c5f00ae 100644 --- a/asv_bench/benchmarks/combine.py +++ b/asv_bench/benchmarks/combine.py @@ -2,7 +2,7 @@ import xarray as xr -from . import parameterized, randn, requires_dask +from . import requires_dask class Combine1d: From 5670331dfa021186319e2d37767e25f43068cf8f Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Thu, 11 May 2023 06:16:28 +0200 Subject: [PATCH 19/50] Update combine.py --- asv_bench/benchmarks/combine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/asv_bench/benchmarks/combine.py b/asv_bench/benchmarks/combine.py index 2bc5c5f00ae..c91e6ba676d 100644 --- a/asv_bench/benchmarks/combine.py +++ b/asv_bench/benchmarks/combine.py @@ -42,7 +42,7 @@ def setup(self): coords = {"time": ("time", coord)} self.dsA0 = xr.Dataset(data_vars, coords=coords) - self.dsA1 = xr.Dataset(data_vars) + self.dsA1 = xr.Dataset(data_vars, coords=coords) class Combine3d: From b11afe823487afbec4131ff2f8a10d2aa873d546 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Thu, 11 May 2023 06:17:01 +0200 Subject: [PATCH 20/50] Update combine.py --- asv_bench/benchmarks/combine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/asv_bench/benchmarks/combine.py b/asv_bench/benchmarks/combine.py index c91e6ba676d..274570a2cc3 100644 --- a/asv_bench/benchmarks/combine.py +++ b/asv_bench/benchmarks/combine.py @@ -27,7 +27,7 @@ def time_combine_by_coords(self): xr.combine_by_coords(datasets) -class Combine1dDask: +class Combine1dDask(Combine1d): """Benchmark concatenating and merging large datasets""" def setup(self): From e1938a8457bbc5212d8bcb9dc2bc2418a5e70f14 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Thu, 11 May 2023 22:30:34 +0200 Subject: [PATCH 21/50] Update combine.py --- asv_bench/benchmarks/combine.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/asv_bench/benchmarks/combine.py b/asv_bench/benchmarks/combine.py index 274570a2cc3..318b9fda070 100644 --- a/asv_bench/benchmarks/combine.py +++ b/asv_bench/benchmarks/combine.py @@ -35,14 +35,13 @@ def setup(self): requires_dask() t_size = 8000 - var = xr.Variable(dims=("time",), data=np.random.randn(t_size)).chunk() - coord = xr.Variable(dims=("time",), data=np.random.randn(t_size)) + t = np.arange(t_size) + var = xr.Variable(dims=("T",), data=np.random.randn(t_size)).chunk() - data_vars = {f"long_name_{v}": ("time", var) for v in range(500)} - coords = {"time": ("time", coord)} + data_vars = {f"long_name_{v}": ("T", var) for v in range(500)} - self.dsA0 = xr.Dataset(data_vars, coords=coords) - self.dsA1 = xr.Dataset(data_vars, coords=coords) + self.dsA0 = xr.Dataset(data_vars, coords={"T": t}) + self.dsA1 = xr.Dataset(data_vars, coords={"T": t + t_size}) class Combine3d: From 43fd7a2425891b259028e5ac20ebf0067ffee8eb Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 14 May 2023 11:42:16 +0200 Subject: [PATCH 22/50] list not needed --- xarray/core/combine.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/xarray/core/combine.py b/xarray/core/combine.py index 8106c295f5a..cee27300beb 100644 --- a/xarray/core/combine.py +++ b/xarray/core/combine.py @@ -970,10 +970,9 @@ def combine_by_coords( # Perform the multidimensional combine on each group of data variables # before merging back together - concatenated_grouped_by_data_vars = [] - for vars, datasets_with_same_vars in grouped_by_vars: - concatenated = _combine_single_variable_hypercube( - list(datasets_with_same_vars), + concatenated_grouped_by_data_vars = tuple( + _combine_single_variable_hypercube( + tuple(datasets_with_same_vars), fill_value=fill_value, data_vars=data_vars, coords=coords, @@ -981,7 +980,8 @@ def combine_by_coords( join=join, combine_attrs=combine_attrs, ) - concatenated_grouped_by_data_vars.append(concatenated) + for vars, datasets_with_same_vars in grouped_by_vars + ) return merge( concatenated_grouped_by_data_vars, From a59635bcb5ca3e6eab94cb5a11ed1b3e5e07eeef Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 14 May 2023 11:44:31 +0200 Subject: [PATCH 23/50] dim is usually string, might be faster to check for that --- xarray/core/common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/common.py b/xarray/core/common.py index f6abcba1ff0..4ce191ab3b4 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -209,7 +209,7 @@ def get_axis_num(self, dim: Hashable | Iterable[Hashable]) -> int | tuple[int, . int or tuple of int Axis number or numbers corresponding to the given dimensions. """ - if isinstance(dim, Iterable) and not isinstance(dim, str): + if not isinstance(dim, str) and isinstance(dim, Iterable): return tuple(self._get_axis_num(d) for d in dim) else: return self._get_axis_num(dim) From 70be8c9b1a32cb5fb0c70edf3fa5683ff0a5e2d0 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 14 May 2023 11:45:44 +0200 Subject: [PATCH 24/50] first_var.dims doesn't change and can be defined 1 time --- xarray/core/variable.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/xarray/core/variable.py b/xarray/core/variable.py index cc24870dee4..885ca93f5c3 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -2073,12 +2073,13 @@ def concat( # twice variables = list(variables) first_var = variables[0] + first_var_dims = first_var.dims - arrays = [v.data for v in variables] + arrays = [v._data for v in variables] - if dim in first_var.dims: + if dim in first_var_dims: axis = first_var.get_axis_num(dim) - dims = first_var.dims + dims = first_var_dims data = duck_array_ops.concatenate(arrays, axis=axis) if positions is not None: # TODO: deprecate this option -- we don't need it for groupby @@ -2087,7 +2088,7 @@ def concat( data = duck_array_ops.take(data, indices, axis=axis) else: axis = 0 - dims = (dim,) + first_var.dims + dims = (dim,) + first_var_dims data = duck_array_ops.stack(arrays, axis=axis) attrs = merge_attrs( @@ -2096,9 +2097,9 @@ def concat( encoding = dict(first_var.encoding) if not shortcut: for var in variables: - if var.dims != first_var.dims: + if var.dims != first_var_dims: raise ValueError( - f"Variable has dimensions {list(var.dims)} but first Variable has dimensions {list(first_var.dims)}" + f"Variable has dimensions {list(var.dims)} but first Variable has dimensions {list(first_var_dims)}" ) return cls(dims, data, attrs, encoding, fastpath=True) From 68a0c083d3b0ff7889f4a51b7eaca47cadfbf3e5 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 14 May 2023 13:36:24 +0200 Subject: [PATCH 25/50] mask bad points rather than append good points --- xarray/core/concat.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/xarray/core/concat.py b/xarray/core/concat.py index 62fc82e31af..a677d25deae 100644 --- a/xarray/core/concat.py +++ b/xarray/core/concat.py @@ -571,6 +571,8 @@ def get_indexes(name): # create concatenation index, needed for later reindexing concat_index = np.arange(sum(concat_dim_lengths)) + concat_index_size = concat_index.size + variable_index_mask = np.ones(concat_index_size, dtype=bool) # stack up each variable and/or index to fill-out the dataset (in order) # n.b. this loop preserves variable order, needed for groupby. @@ -578,15 +580,15 @@ def get_indexes(name): if name in concat_over and name not in result_indexes: variables = [] # variable_index = [] - variable_index = np.array([], dtype=np.intp) + mask = variable_index_mask.copy() var_concat_dim_length = [] for i, ds in enumerate(datasets): if name in ds.variables: variables.append(ds[name].variable) # add to variable index, needed for reindexing dim_len = concat_dim_lengths[i] - var_idx = sum(concat_dim_lengths[:i]) + np.arange(dim_len) - variable_index = np.append(variable_index, var_idx) + # var_idx = sum(concat_dim_lengths[:i]) + np.arange(dim_len) + # variable_index = np.append(variable_index, var_idx) var_concat_dim_length.append(dim_len) # var_idx = [ # sum(concat_dim_lengths[:i]) + k @@ -600,6 +602,11 @@ def get_indexes(name): raise ValueError( f"coordinate {name!r} not present in all datasets." ) + start = sum(concat_dim_lengths[:i]) + end = start + concat_dim_lengths[i] + mask[slice(start, end)] = False + + variable_index = concat_index[mask] vars = ensure_common_dims(variables, var_concat_dim_length) # Try to concatenate the indexes, concatenate the variables when no index @@ -630,7 +637,7 @@ def get_indexes(name): vars, dim, positions, combine_attrs=combine_attrs ) # reindex if variable is not present in all datasets - if len(variable_index) < len(concat_index): + if len(variable_index) < concat_index_size: combined_var = reindex_variables( variables={name: combined_var}, dim_pos_indexers={ From 2ebf78a1f0e0138a83bcd2d8306200051b45cdc1 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 14 May 2023 13:37:09 +0200 Subject: [PATCH 26/50] reduce duplicated code --- xarray/core/dataset.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 442a359f6f7..f20cbbebc6d 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -1356,11 +1356,14 @@ def _construct_dataarray(self, name: Hashable) -> DataArray: needed_dims = set(variable.dims) + coord_name = self._coord_names coords: dict[Hashable, Variable] = {} # preserve ordering for k in self._variables: - if k in self._coord_names and set(self.variables[k].dims) <= needed_dims: - coords[k] = self.variables[k] + if k in coord_name: + var = self._variables[k] + if set(var.dims) <= needed_dims: + coords[k] = var indexes = filter_indexes_from_coords(self._indexes, set(coords)) From dd813252144d7f5ecd833ce3b91ee8f8627c943e Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 14 May 2023 13:42:58 +0200 Subject: [PATCH 27/50] don't think id() is required here. --- xarray/core/indexes.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index 93e9e535fe3..d44c2b51d3b 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -1495,11 +1495,11 @@ def filter_indexes_from_coords( of coordinate names. """ - filtered_indexes: dict[Any, Index] = dict(**indexes) + filtered_indexes: dict[Any, Index] = dict(indexes) index_coord_names: dict[Hashable, set[Hashable]] = defaultdict(set) for name, idx in indexes.items(): - index_coord_names[id(idx)].add(name) + index_coord_names[idx].add(name) for idx_coord_names in index_coord_names.values(): if not idx_coord_names <= filtered_coord_names: From 8d3d1528cc71403f3f9f5b48b41874e7bd9cf402 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 14 May 2023 21:09:44 +0200 Subject: [PATCH 28/50] get dtype directly instead of through result_dtype --- xarray/core/dtypes.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xarray/core/dtypes.py b/xarray/core/dtypes.py index 4d8583cfe65..9d7242233bb 100644 --- a/xarray/core/dtypes.py +++ b/xarray/core/dtypes.py @@ -156,7 +156,7 @@ def is_datetime_like(dtype): return np.issubdtype(dtype, np.datetime64) or np.issubdtype(dtype, np.timedelta64) -def result_type(*arrays_and_dtypes): +def result_type(*arrays_and_dtypes) -> np.dtype: """Like np.result_type, but with type promotion rules matching pandas. Examples of changed behavior: @@ -172,7 +172,7 @@ def result_type(*arrays_and_dtypes): ------- numpy.dtype for the result. """ - types = {np.result_type(t).type for t in arrays_and_dtypes} + types = {t.dtype.type for t in arrays_and_dtypes} for left, right in PROMOTE_TO_OBJECT: if any(issubclass(t, left) for t in types) and any( From 476b1e086ae18694a5fa6c6eda567e438c4ce512 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 14 May 2023 21:12:08 +0200 Subject: [PATCH 29/50] seems better to delete rather than append, --- xarray/core/concat.py | 18 ++++-------------- 1 file changed, 4 insertions(+), 14 deletions(-) diff --git a/xarray/core/concat.py b/xarray/core/concat.py index a677d25deae..9167a38112a 100644 --- a/xarray/core/concat.py +++ b/xarray/core/concat.py @@ -576,32 +576,22 @@ def get_indexes(name): # stack up each variable and/or index to fill-out the dataset (in order) # n.b. this loop preserves variable order, needed for groupby. + ndatasets = len(datasets) for name in vars_order: if name in concat_over and name not in result_indexes: variables = [] - # variable_index = [] mask = variable_index_mask.copy() - var_concat_dim_length = [] + var_concat_dim_length = concat_dim_lengths.copy() for i, ds in enumerate(datasets): if name in ds.variables: variables.append(ds[name].variable) - # add to variable index, needed for reindexing - dim_len = concat_dim_lengths[i] - # var_idx = sum(concat_dim_lengths[:i]) + np.arange(dim_len) - # variable_index = np.append(variable_index, var_idx) - var_concat_dim_length.append(dim_len) - # var_idx = [ - # sum(concat_dim_lengths[:i]) + k - # for k in range(concat_dim_lengths[i]) - # ] - # variable_index.extend(var_idx) - # var_concat_dim_length.append(len(var_idx)) else: # raise if coordinate not in all datasets if name in coord_names: raise ValueError( f"coordinate {name!r} not present in all datasets." ) + del var_concat_dim_length[i] start = sum(concat_dim_lengths[:i]) end = start + concat_dim_lengths[i] mask[slice(start, end)] = False @@ -613,7 +603,7 @@ def get_indexes(name): # is found on all datasets. indexes: list[Index] = list(get_indexes(name)) if indexes: - if len(indexes) < len(datasets): + if len(indexes) < ndatasets: raise ValueError( f"{name!r} must have either an index or no index in all datasets, " f"found {len(indexes)}/{len(datasets)} datasets with an index." From 674638a33b3f5f998c31fb40bc75f87545e775de Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 14 May 2023 21:13:08 +0200 Subject: [PATCH 30/50] use internal fastpath if it's a dataset, values should be fine then --- xarray/core/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index f20cbbebc6d..3bd1011620a 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -607,7 +607,7 @@ def __init__( ) if isinstance(coords, Dataset): - coords = coords.variables + coords = coords._variables variables, coord_names, dims, indexes, _ = merge_data_and_coords( data_vars, coords, compat="broadcast_equals" From 05206f8dc9e5079ec5d640ab952eb553157df144 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 14 May 2023 21:13:33 +0200 Subject: [PATCH 31/50] Change isinstance order. --- xarray/core/merge.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/xarray/core/merge.py b/xarray/core/merge.py index a3287268407..747263ddaf5 100644 --- a/xarray/core/merge.py +++ b/xarray/core/merge.py @@ -556,7 +556,12 @@ def merge_coords( return variables, out_indexes -def merge_data_and_coords(data_vars, coords, compat="broadcast_equals", join="outer"): +def merge_data_and_coords( + data_vars: Mapping[Any, Any], + coords: Mapping[Any, Any], + compat: CompatOptions = "broadcast_equals", + join: JoinOptions = "outer", +) -> _MergeResult: """Used in Dataset.__init__.""" indexes, coords = _create_indexes_from_coords(coords, data_vars) objects = [data_vars, coords] @@ -570,7 +575,9 @@ def merge_data_and_coords(data_vars, coords, compat="broadcast_equals", join="ou ) -def _create_indexes_from_coords(coords, data_vars=None): +def _create_indexes_from_coords( + coords: Mapping[Any, Any], data_vars: Mapping[Any, Any] | None = None +) -> tuple[dict, dict]: """Maybe create default indexes from a mapping of coordinates. Return those indexes and updated coordinates. @@ -1035,7 +1042,7 @@ def dataset_merge_method( # method due for backwards compatibility # TODO: consider deprecating it? - if isinstance(overwrite_vars, Iterable) and not isinstance(overwrite_vars, str): + if not isinstance(overwrite_vars, str) and isinstance(overwrite_vars, Iterable): overwrite_vars = set(overwrite_vars) else: overwrite_vars = {overwrite_vars} From e0dae6db1e4cda80aea9f4a935a310e30c807364 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 14 May 2023 21:14:19 +0200 Subject: [PATCH 32/50] use fastpath if already xarray objtect --- xarray/core/variable.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 885ca93f5c3..48644b969d3 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -232,7 +232,7 @@ def _possibly_convert_datetime_or_timedelta_index(data): return data -def as_compatible_data(data, fastpath=False): +def as_compatible_data(data, fastpath: bool = False): """Prepare and wrap data to put in a Variable. - If data does not have the necessary attributes, convert it to ndarray. @@ -250,7 +250,7 @@ def as_compatible_data(data, fastpath=False): from xarray.core.dataarray import DataArray if isinstance(data, (Variable, DataArray)): - return data.data + return data._data if isinstance(data, NON_NUMPY_SUPPORTED_ARRAY_TYPES): data = _possibly_convert_datetime_or_timedelta_index(data) @@ -670,7 +670,8 @@ def dims(self, value: str | Iterable[Hashable]) -> None: def _parse_dimensions(self, dims: str | Iterable[Hashable]) -> tuple[Hashable, ...]: if isinstance(dims, str): dims = (dims,) - dims = tuple(dims) + else: + dims = tuple(dims) if len(dims) != self.ndim: raise ValueError( f"dimensions {dims} must have the same length as the " From 9cc6c2df0fd4ef315a7644bb5c4daedb79f3b543 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 14 May 2023 21:30:28 +0200 Subject: [PATCH 33/50] Update variable.py --- xarray/core/variable.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 48644b969d3..e489ed15840 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -247,10 +247,13 @@ def as_compatible_data(data, fastpath: bool = False): # can't use fastpath (yet) for scalars return _maybe_wrap_data(data) + if isinstance(data, Variable): + return data._data + from xarray.core.dataarray import DataArray - if isinstance(data, (Variable, DataArray)): - return data._data + if isinstance(data, DataArray): + return data.variable._data if isinstance(data, NON_NUMPY_SUPPORTED_ARRAY_TYPES): data = _possibly_convert_datetime_or_timedelta_index(data) From f03154c979b5ab4b4c3c44e6247f3b5e5675a78c Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 14 May 2023 21:42:48 +0200 Subject: [PATCH 34/50] Update dtypes.py --- xarray/core/dtypes.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/xarray/core/dtypes.py b/xarray/core/dtypes.py index 9d7242233bb..17417e724a6 100644 --- a/xarray/core/dtypes.py +++ b/xarray/core/dtypes.py @@ -37,11 +37,11 @@ def __eq__(self, other): # instead of following NumPy's own type-promotion rules. These type promotion # rules match pandas instead. For reference, see the NumPy type hierarchy: # https://numpy.org/doc/stable/reference/arrays.scalars.html -PROMOTE_TO_OBJECT = [ - {np.number, np.character}, # numpy promotes to character - {np.bool_, np.character}, # numpy promotes to character - {np.bytes_, np.unicode_}, # numpy promotes to unicode -] +PROMOTE_TO_OBJECT: tuple[tuple[np.generic, np.generic]] = ( + (np.number, np.character), # numpy promotes to character + (np.bool_, np.character), # numpy promotes to character + (np.bytes_, np.unicode_), # numpy promotes to unicode +) def maybe_promote(dtype): From 529e38692d31dec6f9cc2c32186b8b485d61b1ae Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 14 May 2023 22:16:09 +0200 Subject: [PATCH 35/50] typing fixes --- xarray/core/dtypes.py | 2 +- xarray/core/merge.py | 14 +++++++++----- xarray/core/variable.py | 2 +- 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/xarray/core/dtypes.py b/xarray/core/dtypes.py index 17417e724a6..23d797e539f 100644 --- a/xarray/core/dtypes.py +++ b/xarray/core/dtypes.py @@ -37,7 +37,7 @@ def __eq__(self, other): # instead of following NumPy's own type-promotion rules. These type promotion # rules match pandas instead. For reference, see the NumPy type hierarchy: # https://numpy.org/doc/stable/reference/arrays.scalars.html -PROMOTE_TO_OBJECT: tuple[tuple[np.generic, np.generic]] = ( +PROMOTE_TO_OBJECT: tuple[tuple[type[np.generic], type[np.generic]], ...] = ( (np.number, np.character), # numpy promotes to character (np.bool_, np.character), # numpy promotes to character (np.bytes_, np.unicode_), # numpy promotes to unicode diff --git a/xarray/core/merge.py b/xarray/core/merge.py index 747263ddaf5..2578bd8198b 100644 --- a/xarray/core/merge.py +++ b/xarray/core/merge.py @@ -195,12 +195,12 @@ def _assert_prioritized_valid( def merge_collected( - grouped: dict[Hashable, list[MergeElement]], + grouped: dict[Any, list[MergeElement]], prioritized: Mapping[Any, MergeElement] | None = None, compat: CompatOptions = "minimal", combine_attrs: CombineAttrsOptions = "override", - equals: dict[Hashable, bool] | None = None, -) -> tuple[dict[Hashable, Variable], dict[Hashable, Index]]: + equals: dict[Any, bool] | None = None, +) -> tuple[dict[Any, Variable], dict[Any, Index]]: """Merge dicts of variables, while resolving conflicts appropriately. Parameters @@ -612,7 +612,11 @@ def _create_indexes_from_coords( return indexes, updated_coords -def assert_valid_explicit_coords(variables, dims, explicit_coords): +def assert_valid_explicit_coords( + variables: Mapping[Any, Any], + dims: Mapping[Any, int], + explicit_coords: Iterable[Hashable], +) -> None: """Validate explicit coordinate names/dims. Raise a MergeError if an explicit coord shares a name with a dimension @@ -695,7 +699,7 @@ def merge_core( join: JoinOptions = "outer", combine_attrs: CombineAttrsOptions = "override", priority_arg: int | None = None, - explicit_coords: Sequence | None = None, + explicit_coords: Iterable[Hashable] | None = None, indexes: Mapping[Any, Any] | None = None, fill_value: object = dtypes.NA, ) -> _MergeResult: diff --git a/xarray/core/variable.py b/xarray/core/variable.py index e489ed15840..9a6f742e19e 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -3226,7 +3226,7 @@ def concat( return Variable.concat(variables, dim, positions, shortcut, combine_attrs) -def calculate_dimensions(variables: Mapping[Any, Variable]) -> dict[Hashable, int]: +def calculate_dimensions(variables: Mapping[Any, Variable]) -> dict[Any, int]: """Calculate the dimensions corresponding to a set of variables. Returns dictionary mapping from dimension names to sizes. Raises ValueError From b7492ca1fbbedf11ab623fe0e633b2a64ccc329b Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 14 May 2023 22:38:49 +0200 Subject: [PATCH 36/50] more typing fixes --- xarray/core/dtypes.py | 6 ++++-- xarray/tests/test_dtypes.py | 12 ++++++------ 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/xarray/core/dtypes.py b/xarray/core/dtypes.py index 23d797e539f..f5a184289c9 100644 --- a/xarray/core/dtypes.py +++ b/xarray/core/dtypes.py @@ -156,7 +156,9 @@ def is_datetime_like(dtype): return np.issubdtype(dtype, np.datetime64) or np.issubdtype(dtype, np.timedelta64) -def result_type(*arrays_and_dtypes) -> np.dtype: +def result_type( + *arrays_and_dtypes: np.typing.ArrayLike | np.typing.DTypeLike, +) -> np.dtype: """Like np.result_type, but with type promotion rules matching pandas. Examples of changed behavior: @@ -172,7 +174,7 @@ def result_type(*arrays_and_dtypes) -> np.dtype: ------- numpy.dtype for the result. """ - types = {t.dtype.type for t in arrays_and_dtypes} + types = {np.result_type(t).type for t in arrays_and_dtypes} for left, right in PROMOTE_TO_OBJECT: if any(issubclass(t, left) for t in types) and any( diff --git a/xarray/tests/test_dtypes.py b/xarray/tests/test_dtypes.py index 1c942a1e6c8..490520c8f54 100644 --- a/xarray/tests/test_dtypes.py +++ b/xarray/tests/test_dtypes.py @@ -18,17 +18,17 @@ ([np.bytes_, np.unicode_], np.object_), ], ) -def test_result_type(args, expected): +def test_result_type(args, expected) -> None: actual = dtypes.result_type(*args) assert actual == expected -def test_result_type_scalar(): +def test_result_type_scalar() -> None: actual = dtypes.result_type(np.arange(3, dtype=np.float32), np.nan) assert actual == np.float32 -def test_result_type_dask_array(): +def test_result_type_dask_array() -> None: # verify it works without evaluating dask arrays da = pytest.importorskip("dask.array") dask = pytest.importorskip("dask") @@ -50,7 +50,7 @@ def error(): @pytest.mark.parametrize("obj", [1.0, np.inf, "ab", 1.0 + 1.0j, True]) -def test_inf(obj): +def test_inf(obj) -> None: assert dtypes.INF > obj assert dtypes.NINF < obj @@ -85,7 +85,7 @@ def test_inf(obj): ("V", (np.dtype("O"), "nan")), # dtype('V') ], ) -def test_maybe_promote(kind, expected): +def test_maybe_promote(kind, expected) -> None: # 'g': np.float128 is not tested : not available on all platforms # 'G': np.complex256 is not tested : not available on all platforms @@ -94,7 +94,7 @@ def test_maybe_promote(kind, expected): assert str(actual[1]) == expected[1] -def test_nat_types_membership(): +def test_nat_types_membership() -> None: assert np.datetime64("NaT").dtype in dtypes.NAT_TYPES assert np.timedelta64("NaT").dtype in dtypes.NAT_TYPES assert np.float64 not in dtypes.NAT_TYPES From cf51f162d2102d64009ab4cce633aaf825cab415 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 14 May 2023 22:55:12 +0200 Subject: [PATCH 37/50] test undoing as_compatible_data --- xarray/core/variable.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 9a6f742e19e..dac576c4e91 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -247,13 +247,10 @@ def as_compatible_data(data, fastpath: bool = False): # can't use fastpath (yet) for scalars return _maybe_wrap_data(data) - if isinstance(data, Variable): - return data._data - from xarray.core.dataarray import DataArray - if isinstance(data, DataArray): - return data.variable._data + if isinstance(data, (Variable, DataArray)): + return data.data if isinstance(data, NON_NUMPY_SUPPORTED_ARRAY_TYPES): data = _possibly_convert_datetime_or_timedelta_index(data) From 4b6a4c6cfddc4b6e972b99185ed078760d7b0e8e Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 14 May 2023 23:03:34 +0200 Subject: [PATCH 38/50] undo concat_dim_length deletion --- xarray/core/concat.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xarray/core/concat.py b/xarray/core/concat.py index 9167a38112a..fb64e41a381 100644 --- a/xarray/core/concat.py +++ b/xarray/core/concat.py @@ -581,17 +581,17 @@ def get_indexes(name): if name in concat_over and name not in result_indexes: variables = [] mask = variable_index_mask.copy() - var_concat_dim_length = concat_dim_lengths.copy() + var_concat_dim_length = [] for i, ds in enumerate(datasets): if name in ds.variables: variables.append(ds[name].variable) + var_concat_dim_length.append(concat_dim_lengths[i]) else: # raise if coordinate not in all datasets if name in coord_names: raise ValueError( f"coordinate {name!r} not present in all datasets." ) - del var_concat_dim_length[i] start = sum(concat_dim_lengths[:i]) end = start + concat_dim_lengths[i] mask[slice(start, end)] = False From 2cd984d7d7cf5257084ac3ad661979aed024b1ad Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Mon, 15 May 2023 00:41:48 +0200 Subject: [PATCH 39/50] Update xarray/core/concat.py --- xarray/core/concat.py | 1 - 1 file changed, 1 deletion(-) diff --git a/xarray/core/concat.py b/xarray/core/concat.py index fb64e41a381..705738f86a8 100644 --- a/xarray/core/concat.py +++ b/xarray/core/concat.py @@ -544,7 +544,6 @@ def ensure_common_dims(vars, concat_dim_lengths): # ensure each variable with the given name shares the same # dimensions and the same shape for all of them except along the # concat dimension - # common_dims = tuple(pd.unique([d for v in vars for d in v.dims])) common_dims = tuple(dict.fromkeys(d for v in vars for d in v.dims)) if dim not in common_dims: common_dims = (dim,) + common_dims From 86eb72a228b8a9a4ed23f67e57c1bf2b44eb6539 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Tue, 16 May 2023 23:07:42 +0200 Subject: [PATCH 40/50] Remove .copy and sum --- xarray/core/concat.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/xarray/core/concat.py b/xarray/core/concat.py index 705738f86a8..b2074df64b7 100644 --- a/xarray/core/concat.py +++ b/xarray/core/concat.py @@ -569,7 +569,8 @@ def get_indexes(name): yield PandasIndex(data, dim, coord_dtype=var.dtype) # create concatenation index, needed for later reindexing - concat_index = np.arange(sum(concat_dim_lengths)) + file_start_indexes = np.append(0, np.cumsum(concat_dim_lengths)) + concat_index = np.arange(file_start_indexes[-1]) concat_index_size = concat_index.size variable_index_mask = np.ones(concat_index_size, dtype=bool) @@ -579,7 +580,7 @@ def get_indexes(name): for name in vars_order: if name in concat_over and name not in result_indexes: variables = [] - mask = variable_index_mask.copy() + variable_index_mask.fill(True) var_concat_dim_length = [] for i, ds in enumerate(datasets): if name in ds.variables: @@ -591,11 +592,13 @@ def get_indexes(name): raise ValueError( f"coordinate {name!r} not present in all datasets." ) - start = sum(concat_dim_lengths[:i]) - end = start + concat_dim_lengths[i] - mask[slice(start, end)] = False - variable_index = concat_index[mask] + # Mask out the indexes without the name:: + start = file_start_indexes[i] + end = file_start_indexes[i + 1] + variable_index_mask[slice(start, end)] = False + + variable_index = concat_index[variable_index_mask] vars = ensure_common_dims(variables, var_concat_dim_length) # Try to concatenate the indexes, concatenate the variables when no index From ee2c3f68c51a2ab33af6200926b00ce808f67823 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Thu, 18 May 2023 23:58:30 +0200 Subject: [PATCH 41/50] Update concat.py --- xarray/core/concat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/concat.py b/xarray/core/concat.py index b2074df64b7..f61f0e788a9 100644 --- a/xarray/core/concat.py +++ b/xarray/core/concat.py @@ -593,7 +593,7 @@ def get_indexes(name): f"coordinate {name!r} not present in all datasets." ) - # Mask out the indexes without the name:: + # Mask out the indexes without the name: start = file_start_indexes[i] end = file_start_indexes[i + 1] variable_index_mask[slice(start, end)] = False From 1e514f6c0c281a288a4622b3ae1303f249f44800 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 21 May 2023 11:17:50 +0200 Subject: [PATCH 42/50] Use OrderedSet --- xarray/core/concat.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/xarray/core/concat.py b/xarray/core/concat.py index f61f0e788a9..ba514c6a0f9 100644 --- a/xarray/core/concat.py +++ b/xarray/core/concat.py @@ -544,7 +544,8 @@ def ensure_common_dims(vars, concat_dim_lengths): # ensure each variable with the given name shares the same # dimensions and the same shape for all of them except along the # concat dimension - common_dims = tuple(dict.fromkeys(d for v in vars for d in v.dims)) + common_dims = tuple(utils.OrderedSet(d for v in vars for d in v.dims)) + if dim not in common_dims: common_dims = (dim,) + common_dims for var, dim_len in zip(vars, concat_dim_lengths): From 0d0b76eda08b4907adaa26b8d2a5bbfaa7722f96 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Thu, 25 May 2023 22:48:11 +0200 Subject: [PATCH 43/50] Apply suggestions from code review --- asv_bench/benchmarks/combine.py | 6 +++--- xarray/core/merge.py | 2 +- xarray/core/variable.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/asv_bench/benchmarks/combine.py b/asv_bench/benchmarks/combine.py index 318b9fda070..772d888306c 100644 --- a/asv_bench/benchmarks/combine.py +++ b/asv_bench/benchmarks/combine.py @@ -8,7 +8,7 @@ class Combine1d: """Benchmark concatenating and merging large datasets""" - def setup(self): + def setup(self) -> None: """Create 2 datasets with two different variables""" t_size = 8000 @@ -20,7 +20,7 @@ def setup(self): {"A": xr.DataArray(data, coords={"T": t + t_size}, dims=("T"))} ) - def time_combine_by_coords(self): + def time_combine_by_coords(self) -> None: """Also has to load and arrange t coordinate""" datasets = [self.dsA0, self.dsA1] @@ -30,7 +30,7 @@ def time_combine_by_coords(self): class Combine1dDask(Combine1d): """Benchmark concatenating and merging large datasets""" - def setup(self): + def setup(self) -> None: """Create 2 datasets with two different variables""" requires_dask() diff --git a/xarray/core/merge.py b/xarray/core/merge.py index 2578bd8198b..56e51256ba1 100644 --- a/xarray/core/merge.py +++ b/xarray/core/merge.py @@ -200,7 +200,7 @@ def merge_collected( compat: CompatOptions = "minimal", combine_attrs: CombineAttrsOptions = "override", equals: dict[Any, bool] | None = None, -) -> tuple[dict[Any, Variable], dict[Any, Index]]: +) -> tuple[dict[Hashable, Variable], dict[Hashable, Index]]: """Merge dicts of variables, while resolving conflicts appropriately. Parameters diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 2e59d4d3221..83ccbc9a1cf 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -3260,7 +3260,7 @@ def concat( return Variable.concat(variables, dim, positions, shortcut, combine_attrs) -def calculate_dimensions(variables: Mapping[Any, Variable]) -> dict[Any, int]: +def calculate_dimensions(variables: Mapping[Any, Variable]) -> dict[Hashable, int]: """Calculate the dimensions corresponding to a set of variables. Returns dictionary mapping from dimension names to sizes. Raises ValueError From 15e2783087546161b9d88c11cdc1ac9932488340 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Tue, 30 May 2023 21:46:54 +0200 Subject: [PATCH 44/50] Update whats-new.rst --- doc/whats-new.rst | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 2c9d72ebfba..6d72a297dd7 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -26,7 +26,8 @@ New Features Breaking changes ~~~~~~~~~~~~~~~~ - +- Improve concatenation performance (:issue:`7833`, :pull:`7824`). + By `Jimmy Westling `_. Deprecations ~~~~~~~~~~~~ From 6f86e986acb1fcd185a75de74a90a01204acfc4c Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Wed, 31 May 2023 17:58:18 +0200 Subject: [PATCH 45/50] Update xarray/core/concat.py --- xarray/core/concat.py | 1 - 1 file changed, 1 deletion(-) diff --git a/xarray/core/concat.py b/xarray/core/concat.py index ba514c6a0f9..07c38a18696 100644 --- a/xarray/core/concat.py +++ b/xarray/core/concat.py @@ -545,7 +545,6 @@ def ensure_common_dims(vars, concat_dim_lengths): # dimensions and the same shape for all of them except along the # concat dimension common_dims = tuple(utils.OrderedSet(d for v in vars for d in v.dims)) - if dim not in common_dims: common_dims = (dim,) + common_dims for var, dim_len in zip(vars, concat_dim_lengths): From db25a0b95f60d7a558663cd48df521bb6457ecbc Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Wed, 31 May 2023 18:53:55 +0200 Subject: [PATCH 46/50] no need to check arrays if cupy isnt even installed --- xarray/core/duck_array_ops.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index 79a7ba007b4..4f245e59f73 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -195,7 +195,9 @@ def asarray(data, xp=np): def as_shared_dtype(scalars_or_arrays, xp=np): """Cast a arrays to a shared dtype using xarray's type promotion rules.""" array_type_cupy = array_type("cupy") - if any(isinstance(x, array_type_cupy) for x in scalars_or_arrays): + if array_type_cupy and any( + isinstance(x, array_type_cupy) for x in scalars_or_arrays + ): import cupy as cp arrays = [asarray(x, xp=cp) for x in scalars_or_arrays] From 629cdeac781779b4619b0b27b50fc036312fb3a1 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Thu, 1 Jun 2023 19:16:26 +0200 Subject: [PATCH 47/50] Update whats-new.rst --- doc/whats-new.rst | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 8f5066e79f1..b03388cb551 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -28,8 +28,7 @@ New Features Breaking changes ~~~~~~~~~~~~~~~~ -- Improve concatenation performance (:issue:`7833`, :pull:`7824`). - By `Jimmy Westling `_. + Deprecations ~~~~~~~~~~~~ @@ -38,7 +37,8 @@ Deprecations Performance ~~~~~~~~~~~ - +- Improve concatenation performance (:issue:`7833`, :pull:`7824`). + By `Jimmy Westling `_. Bug fixes ~~~~~~~~~ From cfaa86662fd913239c79b3817e1e7269adb429dc Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Thu, 1 Jun 2023 19:22:29 +0200 Subject: [PATCH 48/50] Add concat comment --- xarray/core/concat.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/xarray/core/concat.py b/xarray/core/concat.py index 07c38a18696..d7aad8c7188 100644 --- a/xarray/core/concat.py +++ b/xarray/core/concat.py @@ -580,6 +580,8 @@ def get_indexes(name): for name in vars_order: if name in concat_over and name not in result_indexes: variables = [] + # Initialize the mask to all True then set False if any name is missing in + # the datasets: variable_index_mask.fill(True) var_concat_dim_length = [] for i, ds in enumerate(datasets): From fe62336466d86b1501bb597a4775785c79ac50a2 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Thu, 1 Jun 2023 19:34:19 +0200 Subject: [PATCH 49/50] minimize diff --- xarray/core/dataset.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 2bf3622fc9c..81860bede95 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -1396,14 +1396,11 @@ def _construct_dataarray(self, name: Hashable) -> DataArray: needed_dims = set(variable.dims) - coord_name = self._coord_names coords: dict[Hashable, Variable] = {} # preserve ordering for k in self._variables: - if k in coord_name: - var = self._variables[k] - if set(var.dims) <= needed_dims: - coords[k] = var + if k in self._coord_names and set(self._variables[k].dims) <= needed_dims: + coords[k] = self._variables[k] indexes = filter_indexes_from_coords(self._indexes, set(coords)) From 4e905043bcdad29a0ca4aa46b3682262f701fecf Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Thu, 1 Jun 2023 19:37:03 +0200 Subject: [PATCH 50/50] revert sketchy --- xarray/core/indexes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index d44c2b51d3b..9ee9bc374d4 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -1499,7 +1499,7 @@ def filter_indexes_from_coords( index_coord_names: dict[Hashable, set[Hashable]] = defaultdict(set) for name, idx in indexes.items(): - index_coord_names[idx].add(name) + index_coord_names[id(idx)].add(name) for idx_coord_names in index_coord_names.values(): if not idx_coord_names <= filtered_coord_names: