Skip to content

Commit 0aee4fe

Browse files
authored
Optimize some copying (#7209)
* pass some deep copy args along * fix mypy * add swap_dims benchmark * hopefully fix asv
1 parent e4fe194 commit 0aee4fe

File tree

4 files changed

+51
-8
lines changed

4 files changed

+51
-8
lines changed

asv_bench/benchmarks/renaming.py

+27
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
import numpy as np
2+
3+
import xarray as xr
4+
5+
6+
class SwapDims:
7+
param_names = ["size"]
8+
params = [[int(1e3), int(1e5), int(1e7)]]
9+
10+
def setup(self, size: int) -> None:
11+
self.ds = xr.Dataset(
12+
{"a": (("x", "t"), np.ones((size, 2)))},
13+
coords={
14+
"x": np.arange(size),
15+
"y": np.arange(size),
16+
"z": np.arange(size),
17+
"x2": ("x", np.arange(size)),
18+
"y2": ("y", np.arange(size)),
19+
"z2": ("z", np.arange(size)),
20+
},
21+
)
22+
23+
def time_swap_dims(self, size: int) -> None:
24+
self.ds.swap_dims({"x": "xn", "y": "yn", "z": "zn"})
25+
26+
def time_swap_dims_newindex(self, size: int) -> None:
27+
self.ds.swap_dims({"x": "x2", "y": "y2", "z": "z2"})

xarray/core/alignment.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -474,7 +474,7 @@ def override_indexes(self) -> None:
474474
if obj_idx is not None:
475475
for name, var in self.aligned_index_vars[key].items():
476476
new_indexes[name] = aligned_idx
477-
new_variables[name] = var.copy()
477+
new_variables[name] = var.copy(deep=self.copy)
478478

479479
objects[i + 1] = obj._overwrite_indexes(new_indexes, new_variables)
480480

@@ -514,7 +514,7 @@ def _get_indexes_and_vars(
514514
if obj_idx is not None:
515515
for name, var in index_vars.items():
516516
new_indexes[name] = aligned_idx
517-
new_variables[name] = var.copy()
517+
new_variables[name] = var.copy(deep=self.copy)
518518

519519
return new_indexes, new_variables
520520

xarray/core/indexes.py

+21-5
Original file line numberDiff line numberDiff line change
@@ -117,10 +117,12 @@ def __copy__(self) -> Index:
117117
def __deepcopy__(self, memo: dict[int, Any] | None = None) -> Index:
118118
return self._copy(deep=True, memo=memo)
119119

120-
def copy(self, deep: bool = True) -> Index:
120+
def copy(self: T_Index, deep: bool = True) -> T_Index:
121121
return self._copy(deep=deep)
122122

123-
def _copy(self, deep: bool = True, memo: dict[int, Any] | None = None) -> Index:
123+
def _copy(
124+
self: T_Index, deep: bool = True, memo: dict[int, Any] | None = None
125+
) -> T_Index:
124126
cls = self.__class__
125127
copied = cls.__new__(cls)
126128
if deep:
@@ -269,6 +271,9 @@ def get_indexer_nd(index, labels, method=None, tolerance=None):
269271
return indexer
270272

271273

274+
T_PandasIndex = TypeVar("T_PandasIndex", bound="PandasIndex")
275+
276+
272277
class PandasIndex(Index):
273278
"""Wrap a pandas.Index as an xarray compatible index."""
274279

@@ -532,8 +537,11 @@ def rename(self, name_dict, dims_dict):
532537
new_dim = dims_dict.get(self.dim, self.dim)
533538
return self._replace(index, dim=new_dim)
534539

535-
def copy(self, deep=True):
540+
def _copy(
541+
self: T_PandasIndex, deep: bool = True, memo: dict[int, Any] | None = None
542+
) -> T_PandasIndex:
536543
if deep:
544+
# pandas is not using the memo
537545
index = self.index.copy(deep=True)
538546
else:
539547
# index will be copied in constructor
@@ -1265,11 +1273,19 @@ def to_pandas_indexes(self) -> Indexes[pd.Index]:
12651273
return Indexes(indexes, self._variables)
12661274

12671275
def copy_indexes(
1268-
self, deep: bool = True
1276+
self, deep: bool = True, memo: dict[int, Any] | None = None
12691277
) -> tuple[dict[Hashable, T_PandasOrXarrayIndex], dict[Hashable, Variable]]:
12701278
"""Return a new dictionary with copies of indexes, preserving
12711279
unique indexes.
12721280
1281+
Parameters
1282+
----------
1283+
deep : bool, default: True
1284+
Whether the indexes are deep or shallow copied onto the new object.
1285+
memo : dict if object id to copied objects or None, optional
1286+
To prevent infinite recursion deepcopy stores all copied elements
1287+
in this dict.
1288+
12731289
"""
12741290
new_indexes = {}
12751291
new_index_vars = {}
@@ -1285,7 +1301,7 @@ def copy_indexes(
12851301
else:
12861302
convert_new_idx = False
12871303

1288-
new_idx = idx.copy(deep=deep)
1304+
new_idx = idx._copy(deep=deep, memo=memo)
12891305
idx_vars = idx.create_variables(coords)
12901306

12911307
if convert_new_idx:

xarray/core/variable.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -2999,7 +2999,7 @@ def _data_equals(self, other):
29992999

30003000
def to_index_variable(self) -> IndexVariable:
30013001
"""Return this variable as an xarray.IndexVariable"""
3002-
return self.copy()
3002+
return self.copy(deep=False)
30033003

30043004
to_coord = utils.alias(to_index_variable, "to_coord")
30053005

0 commit comments

Comments
 (0)