Skip to content

Commit 1ab02b4

Browse files
andersy005Illviljanpre-commit-ci[bot]dcherian
authored
[namedarray] split .set_dims() into .expand_dims() and broadcast_to() (#8380)
* add `.set_dims()`, `.transpose()` and `.T` to namedarray * more typying fixes * more typing fixes * override set_dims for IndexVariable * fix dims * split `.set_dims()` into `.expand_dims()` and `broadcast_to()` * more typing fixes * update whats-new * update tests * doc fixes * update whats-new * keep `.set_dims()` on `Variable()` * update docs * revert to set_dims * revert to .set_dims on Variable * Update xarray/namedarray/core.py Co-authored-by: Illviljan <[email protected]> * restore .transpose on variable * revert to set_dims in Variable * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix docstring * update test_namedarray * update tests * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Apply suggestions from code review Co-authored-by: Deepak Cherian <[email protected]> * fix formatting issue * fix tests * update expand_dims * update tests * update tests * remove unnecessary guard conditions * Update type hints in NamedArray class and test cases * Refactor NamedArray T property to handle non-2D arrays * Reverse the order of dimensions in x.T * Refactor broadcasting and dimension expansion in NamedArray * update docstring * add todo item * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * use comprehension * use dim * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix imports * formatting only * Apply suggestions from code review Co-authored-by: Deepak Cherian <[email protected]> * [skip-rtd] fix indentation * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * refactor expand_dims to simplify API simplified the `expand_dims` method of the NamedArray class by removing support for multiple and keyword arguments for new dimensions. the change updates the method signature to only accept a single dimension name, enhancing clarity and maintainability. this shift focuses on the common use case of adding a single dimension and moves extended functionality to a separate API. * fix type hint for `dim` parameter in test. * fix typing issues * fix UnboundLocalError: local variable 'flattened_dims' referenced before assignment * fix type hint * ignore typing * update whats-new * adjust the `broadcast_to` method to prohibit adding new dimensions, allowing only the broadcasting of existing ones --------- Co-authored-by: Illviljan <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Deepak Cherian <[email protected]>
1 parent d8c3b1a commit 1ab02b4

14 files changed

+435
-89
lines changed

doc/api-hidden.rst

+29-27
Original file line numberDiff line numberDiff line change
@@ -351,33 +351,35 @@
351351
IndexVariable.values
352352

353353

354-
namedarray.core.NamedArray.all
355-
namedarray.core.NamedArray.any
356-
namedarray.core.NamedArray.attrs
357-
namedarray.core.NamedArray.chunks
358-
namedarray.core.NamedArray.chunksizes
359-
namedarray.core.NamedArray.copy
360-
namedarray.core.NamedArray.count
361-
namedarray.core.NamedArray.cumprod
362-
namedarray.core.NamedArray.cumsum
363-
namedarray.core.NamedArray.data
364-
namedarray.core.NamedArray.dims
365-
namedarray.core.NamedArray.dtype
366-
namedarray.core.NamedArray.get_axis_num
367-
namedarray.core.NamedArray.max
368-
namedarray.core.NamedArray.mean
369-
namedarray.core.NamedArray.median
370-
namedarray.core.NamedArray.min
371-
namedarray.core.NamedArray.nbytes
372-
namedarray.core.NamedArray.ndim
373-
namedarray.core.NamedArray.prod
374-
namedarray.core.NamedArray.reduce
375-
namedarray.core.NamedArray.shape
376-
namedarray.core.NamedArray.size
377-
namedarray.core.NamedArray.sizes
378-
namedarray.core.NamedArray.std
379-
namedarray.core.NamedArray.sum
380-
namedarray.core.NamedArray.var
354+
NamedArray.all
355+
NamedArray.any
356+
NamedArray.attrs
357+
NamedArray.broadcast_to
358+
NamedArray.chunks
359+
NamedArray.chunksizes
360+
NamedArray.copy
361+
NamedArray.count
362+
NamedArray.cumprod
363+
NamedArray.cumsum
364+
NamedArray.data
365+
NamedArray.dims
366+
NamedArray.dtype
367+
NamedArray.expand_dims
368+
NamedArray.get_axis_num
369+
NamedArray.max
370+
NamedArray.mean
371+
NamedArray.median
372+
NamedArray.min
373+
NamedArray.nbytes
374+
NamedArray.ndim
375+
NamedArray.prod
376+
NamedArray.reduce
377+
NamedArray.shape
378+
NamedArray.size
379+
NamedArray.sizes
380+
NamedArray.std
381+
NamedArray.sum
382+
NamedArray.var
381383

382384

383385
plot.plot

doc/whats-new.rst

+3
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@ v2024.02.0 (unreleased)
2323
New Features
2424
~~~~~~~~~~~~
2525

26+
- Add :py:meth:`NamedArray.expand_dims`, :py:meth:`NamedArray.permute_dims` and :py:meth:`NamedArray.broadcast_to`
27+
(:pull:`8380`) By `Anderson Banihirwe <https://github.com/andersy005>`_.
28+
2629
- Xarray now defers to flox's `heuristics <https://flox.readthedocs.io/en/latest/implementation.html#heuristics>`_
2730
to set default `method` for groupby problems. This only applies to ``flox>=0.9``.
2831
By `Deepak Cherian <https://github.com/dcherian>`_.

xarray/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
from xarray.core.options import get_options, set_options
4242
from xarray.core.parallel import map_blocks
4343
from xarray.core.variable import IndexVariable, Variable, as_variable
44+
from xarray.namedarray.core import NamedArray
4445
from xarray.util.print_versions import show_versions
4546

4647
try:
@@ -104,6 +105,7 @@
104105
"IndexSelResult",
105106
"IndexVariable",
106107
"Variable",
108+
"NamedArray",
107109
# Exceptions
108110
"MergeError",
109111
"SerializationWarning",

xarray/core/dataarray.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@
7171
as_compatible_data,
7272
as_variable,
7373
)
74+
from xarray.namedarray.utils import infix_dims
7475
from xarray.plot.accessor import DataArrayPlotAccessor
7576
from xarray.plot.utils import _get_units_from_attrs
7677
from xarray.util.deprecation_helpers import _deprecate_positional_args, deprecate_dims
@@ -3017,7 +3018,7 @@ def transpose(
30173018
Dataset.transpose
30183019
"""
30193020
if dims:
3020-
dims = tuple(utils.infix_dims(dims, self.dims, missing_dims))
3021+
dims = tuple(infix_dims(dims, self.dims, missing_dims))
30213022
variable = self.variable.transpose(*dims)
30223023
if transpose_coords:
30233024
coords: dict[Hashable, Variable] = {}

xarray/core/dataset.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -114,8 +114,6 @@
114114
drop_dims_from_indexers,
115115
either_dict_or_kwargs,
116116
emit_user_level_warning,
117-
infix_dims,
118-
is_dict_like,
119117
is_scalar,
120118
maybe_wrap_array,
121119
)
@@ -126,6 +124,7 @@
126124
broadcast_variables,
127125
calculate_dimensions,
128126
)
127+
from xarray.namedarray.utils import infix_dims, is_dict_like
129128
from xarray.plot.accessor import DatasetPlotAccessor
130129
from xarray.util.deprecation_helpers import _deprecate_positional_args
131130

xarray/core/groupby.py

+1-8
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,7 @@
55
from abc import ABC, abstractmethod
66
from collections.abc import Hashable, Iterator, Mapping, Sequence
77
from dataclasses import dataclass, field
8-
from typing import (
9-
TYPE_CHECKING,
10-
Any,
11-
Callable,
12-
Generic,
13-
Literal,
14-
Union,
15-
)
8+
from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Union
169

1710
import numpy as np
1811
import pandas as pd

xarray/core/utils.py

-30
Original file line numberDiff line numberDiff line change
@@ -846,36 +846,6 @@ def __len__(self) -> int:
846846
return len(self._data) - num_hidden
847847

848848

849-
def infix_dims(
850-
dims_supplied: Collection,
851-
dims_all: Collection,
852-
missing_dims: ErrorOptionsWithWarn = "raise",
853-
) -> Iterator:
854-
"""
855-
Resolves a supplied list containing an ellipsis representing other items, to
856-
a generator with the 'realized' list of all items
857-
"""
858-
if ... in dims_supplied:
859-
if len(set(dims_all)) != len(dims_all):
860-
raise ValueError("Cannot use ellipsis with repeated dims")
861-
if list(dims_supplied).count(...) > 1:
862-
raise ValueError("More than one ellipsis supplied")
863-
other_dims = [d for d in dims_all if d not in dims_supplied]
864-
existing_dims = drop_missing_dims(dims_supplied, dims_all, missing_dims)
865-
for d in existing_dims:
866-
if d is ...:
867-
yield from other_dims
868-
else:
869-
yield d
870-
else:
871-
existing_dims = drop_missing_dims(dims_supplied, dims_all, missing_dims)
872-
if set(existing_dims) ^ set(dims_all):
873-
raise ValueError(
874-
f"{dims_supplied} must be a permuted list of {dims_all}, unless `...` is included"
875-
)
876-
yield from existing_dims
877-
878-
879849
def get_temp_dimname(dims: Container[Hashable], new_dim: Hashable) -> Hashable:
880850
"""Get an new dimension name based on new_dim, that is not used in dims.
881851
If the same name exists, we add an underscore(s) in the head.

xarray/core/variable.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -42,11 +42,11 @@
4242
drop_dims_from_indexers,
4343
either_dict_or_kwargs,
4444
ensure_us_time_resolution,
45-
infix_dims,
4645
is_duck_array,
4746
maybe_coerce_to_str,
4847
)
4948
from xarray.namedarray.core import NamedArray, _raise_if_any_duplicate_dimensions
49+
from xarray.namedarray.utils import infix_dims
5050

5151
NON_NUMPY_SUPPORTED_ARRAY_TYPES = (
5252
indexing.ExplicitlyIndexed,

xarray/namedarray/_array_api.py

+30
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from xarray.namedarray._typing import (
1010
Default,
1111
_arrayapi,
12+
_Axes,
1213
_Axis,
1314
_default,
1415
_Dim,
@@ -196,3 +197,32 @@ def expand_dims(
196197
d.insert(axis, dim)
197198
out = x._new(dims=tuple(d), data=xp.expand_dims(x._data, axis=axis))
198199
return out
200+
201+
202+
def permute_dims(x: NamedArray[Any, _DType], axes: _Axes) -> NamedArray[Any, _DType]:
203+
"""
204+
Permutes the dimensions of an array.
205+
206+
Parameters
207+
----------
208+
x :
209+
Array to permute.
210+
axes :
211+
Permutation of the dimensions of x.
212+
213+
Returns
214+
-------
215+
out :
216+
An array with permuted dimensions. The returned array must have the same
217+
data type as x.
218+
219+
"""
220+
221+
dims = x.dims
222+
new_dims = tuple(dims[i] for i in axes)
223+
if isinstance(x._data, _arrayapi):
224+
xp = _get_data_namespace(x)
225+
out = x._new(dims=new_dims, data=xp.permute_dims(x._data, axes))
226+
else:
227+
out = x._new(dims=new_dims, data=x._data.transpose(axes)) # type: ignore[attr-defined]
228+
return out

xarray/namedarray/_typing.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
Any,
88
Callable,
99
Final,
10+
Literal,
1011
Protocol,
1112
SupportsIndex,
1213
TypeVar,
@@ -311,8 +312,10 @@ def todense(self) -> np.ndarray[Any, _DType_co]:
311312

312313
# NamedArray can most likely use both __array_function__ and __array_namespace__:
313314
_sparsearrayfunction_or_api = (_sparsearrayfunction, _sparsearrayapi)
314-
315315
sparseduckarray = Union[
316316
_sparsearrayfunction[_ShapeType_co, _DType_co],
317317
_sparsearrayapi[_ShapeType_co, _DType_co],
318318
]
319+
320+
ErrorOptions = Literal["raise", "ignore"]
321+
ErrorOptionsWithWarn = Literal["raise", "warn", "ignore"]

0 commit comments

Comments
 (0)