Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,8 @@ Bug fixes
- Allow converting :py:class:`Dataset` or :py:class:`DataArray` objects with a ``MultiIndex``
and at least one other dimension to a ``pandas`` object (:issue:`3008`, :pull:`4442`).
By `ghislainp <https://github.com/ghislainp>`_.
- don't allow passing ``axis`` to :py:meth:`Dataset.reduce` methods (:issue:`3510`, :pull:`4940`).
By `Justus Magin <https://github.com/keewis>`_.

Documentation
~~~~~~~~~~~~~
Expand Down
45 changes: 20 additions & 25 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4656,6 +4656,12 @@ def reduce(
Dataset with this object's DataArrays replaced with new DataArrays
of summarized data and the indicated dimension(s) removed.
"""
if "axis" in kwargs:
raise ValueError(
"passing 'axis' to Dataset reduce methods is ambiguous."
" Please use 'dim' instead."
)

if dim is None or dim is ...:
dims = set(self.dims)
elif isinstance(dim, str) or not isinstance(dim, Iterable):
Expand Down Expand Up @@ -6845,7 +6851,7 @@ def idxmax(
)
)

def argmin(self, dim=None, axis=None, **kwargs):
def argmin(self, dim=None, **kwargs):
"""Indices of the minima of the member variables.

If there are multiple minima, the indices of the first one found will be
Expand All @@ -6859,9 +6865,6 @@ def argmin(self, dim=None, axis=None, **kwargs):
this is deprecated, in future will be an error, since DataArray.argmin will
return a dict with indices for all dimensions, which does not make sense for
a Dataset.
axis : int, optional
Axis over which to apply `argmin`. Only one of the 'dim' and 'axis' arguments
can be supplied.
keep_attrs : bool, optional
If True, the attributes (`attrs`) will be copied from the original
object to the new one. If False (default), the new object will be
Expand All @@ -6879,36 +6882,33 @@ def argmin(self, dim=None, axis=None, **kwargs):
See Also
--------
DataArray.argmin

"""
if dim is None and axis is None:
if dim is None:
warnings.warn(
"Once the behaviour of DataArray.argmin() and Variable.argmin() with "
"neither dim nor axis argument changes to return a dict of indices of "
"each dimension, for consistency it will be an error to call "
"Dataset.argmin() with no argument, since we don't return a dict of "
"Datasets.",
"Once the behaviour of DataArray.argmin() and Variable.argmin() without "
"dim changes to return a dict of indices of each dimension, for "
"consistency it will be an error to call Dataset.argmin() with no argument,"
"since we don't return a dict of Datasets.",
DeprecationWarning,
stacklevel=2,
)
if (
dim is None
or axis is not None
or (not isinstance(dim, Sequence) and dim is not ...)
or isinstance(dim, str)
):
# Return int index if single dimension is passed, and is not part of a
# sequence
argmin_func = getattr(duck_array_ops, "argmin")
return self.reduce(argmin_func, dim=dim, axis=axis, **kwargs)
return self.reduce(argmin_func, dim=dim, **kwargs)
else:
raise ValueError(
"When dim is a sequence or ..., DataArray.argmin() returns a dict. "
"dicts cannot be contained in a Dataset, so cannot call "
"Dataset.argmin() with a sequence or ... for dim"
)

def argmax(self, dim=None, axis=None, **kwargs):
def argmax(self, dim=None, **kwargs):
"""Indices of the maxima of the member variables.

If there are multiple maxima, the indices of the first one found will be
Expand All @@ -6922,9 +6922,6 @@ def argmax(self, dim=None, axis=None, **kwargs):
this is deprecated, in future will be an error, since DataArray.argmax will
return a dict with indices for all dimensions, which does not make sense for
a Dataset.
axis : int, optional
Axis over which to apply `argmax`. Only one of the 'dim' and 'axis' arguments
can be supplied.
keep_attrs : bool, optional
If True, the attributes (`attrs`) will be copied from the original
object to the new one. If False (default), the new object will be
Expand All @@ -6944,26 +6941,24 @@ def argmax(self, dim=None, axis=None, **kwargs):
DataArray.argmax

"""
if dim is None and axis is None:
if dim is None:
warnings.warn(
"Once the behaviour of DataArray.argmax() and Variable.argmax() with "
"neither dim nor axis argument changes to return a dict of indices of "
"each dimension, for consistency it will be an error to call "
"Dataset.argmax() with no argument, since we don't return a dict of "
"Datasets.",
"Once the behaviour of DataArray.argmin() and Variable.argmin() without "
"dim changes to return a dict of indices of each dimension, for "
"consistency it will be an error to call Dataset.argmin() with no argument,"
"since we don't return a dict of Datasets.",
DeprecationWarning,
stacklevel=2,
)
if (
dim is None
or axis is not None
or (not isinstance(dim, Sequence) and dim is not ...)
or isinstance(dim, str)
):
# Return int index if single dimension is passed, and is not part of a
# sequence
argmax_func = getattr(duck_array_ops, "argmax")
return self.reduce(argmax_func, dim=dim, axis=axis, **kwargs)
return self.reduce(argmax_func, dim=dim, **kwargs)
else:
raise ValueError(
"When dim is a sequence or ..., DataArray.argmin() returns a dict. "
Expand Down
9 changes: 3 additions & 6 deletions xarray/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4746,6 +4746,9 @@ def test_reduce(self):

assert_equal(data.mean(dim=[]), data)

with pytest.raises(ValueError):
data.mean(axis=0)

def test_reduce_coords(self):
# regression test for GH1470
data = xr.Dataset({"a": ("x", [1, 2, 3])}, coords={"b": 4})
Expand Down Expand Up @@ -4926,9 +4929,6 @@ def mean_only_one_axis(x, axis):
with raises_regex(TypeError, "missing 1 required positional argument: 'axis'"):
ds.reduce(mean_only_one_axis)

with raises_regex(TypeError, "non-integer axis"):
ds.reduce(mean_only_one_axis, axis=["x", "y"])

def test_reduce_no_axis(self):
def total_sum(x):
return np.sum(x.flatten())
Expand All @@ -4938,9 +4938,6 @@ def total_sum(x):
actual = ds.reduce(total_sum)
assert_identical(expected, actual)

with raises_regex(TypeError, "unexpected keyword argument 'axis'"):
ds.reduce(total_sum, axis=0)

with raises_regex(TypeError, "unexpected keyword argument 'axis'"):
ds.reduce(total_sum, dim="x")

Expand Down
29 changes: 0 additions & 29 deletions xarray/tests/test_units.py
Original file line number Diff line number Diff line change
Expand Up @@ -3972,35 +3972,6 @@ def test_repr(self, func, variant, dtype):
@pytest.mark.parametrize(
"func",
(
function("all"),
function("any"),
pytest.param(
function("argmax"),
marks=pytest.mark.skip(
reason="calling np.argmax as a function on xarray objects is not "
"supported"
),
),
pytest.param(
function("argmin"),
marks=pytest.mark.skip(
reason="calling np.argmin as a function on xarray objects is not "
"supported"
),
),
function("max"),
function("min"),
function("mean"),
pytest.param(
function("median"),
marks=pytest.mark.xfail(reason="median does not work with dataset yet"),
),
function("sum"),
function("prod"),
function("std"),
function("var"),
function("cumsum"),
function("cumprod"),
method("all"),
method("any"),
method("argmax", dim="x"),
Expand Down