Skip to content

Commit

Permalink
cast numpy scalars to arrays in NamedArray.from_array (#10008)
Browse files Browse the repository at this point in the history
* check that aggregations result in array objects

* don't consider numpy scalars as arrays

* changelog [skip-ci]

* retrigger CI

* Update xarray/tests/test_namedarray.py

---------

Co-authored-by: Kai Mühlbauer <[email protected]>
  • Loading branch information
keewis and kmuehlbauer authored Jan 30, 2025
1 parent e84e421 commit f91306a
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 1 deletion.
2 changes: 2 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,8 @@ Bug fixes
By `Kai Mühlbauer <https://github.com/kmuehlbauer>`_.
- Fix weighted ``polyfit`` for arrays with more than two dimensions (:issue:`9972`, :pull:`9974`).
By `Mattia Almansi <https://github.com/malmans2>`_.
- Cast ``numpy`` scalars to arrays in :py:meth:`NamedArray.from_arrays` (:issue:`10005`, :pull:`10008`)
By `Justus Magin <https://github.com/keewis>`_.

Documentation
~~~~~~~~~~~~~
Expand Down
2 changes: 1 addition & 1 deletion xarray/namedarray/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def from_array(

return NamedArray(dims, data, attrs)

if isinstance(data, _arrayfunction_or_api):
if isinstance(data, _arrayfunction_or_api) and not isinstance(data, np.generic):
return NamedArray(dims, data, attrs)

if isinstance(data, tuple):
Expand Down
7 changes: 7 additions & 0 deletions xarray/tests/test_namedarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -591,6 +591,13 @@ def test_warn_on_repeated_dimension_names(self) -> None:
with pytest.warns(UserWarning, match="Duplicate dimension names"):
NamedArray(("x", "x"), np.arange(4).reshape(2, 2))

def test_aggregation(self) -> None:
x: NamedArray[Any, np.dtype[np.int64]]
x = NamedArray(("x", "y"), np.arange(4).reshape(2, 2))

result = x.sum()
assert isinstance(result.data, np.ndarray)


def test_repr() -> None:
x: NamedArray[Any, np.dtype[np.uint64]]
Expand Down

0 comments on commit f91306a

Please sign in to comment.