Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add count_nonzero to specification #803

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions spec/draft/API_specification/searching_functions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ Objects in API

argmax
argmin
count_nonzero
nonzero
searchsorted
where
47 changes: 38 additions & 9 deletions src/array_api_stubs/_draft/searching_functions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
__all__ = ["argmax", "argmin", "nonzero", "searchsorted", "where"]
__all__ = ["argmax", "argmin", "count_nonzero", "nonzero", "searchsorted", "where"]


from ._types import Optional, Tuple, Literal, array
from ._types import Optional, Tuple, Literal, Union, array


def argmax(x: array, /, *, axis: Optional[int] = None, keepdims: bool = False) -> array:
Expand Down Expand Up @@ -54,15 +54,41 @@ def argmin(x: array, /, *, axis: Optional[int] = None, keepdims: bool = False) -
"""


def nonzero(x: array, /) -> Tuple[array, ...]:
def count_nonzero(
x: array,
/,
*,
axis: Optional[Union[int, Tuple[int, ...]]] = None,
keepdims: bool = False,
) -> array:
"""
Returns the indices of the array elements which are non-zero.
Counts the number of array elements which are non-zero.

.. note::
If ``x`` has a complex floating-point data type, non-zero elements are those elements having at least one component (real or imaginary) which is non-zero.
Parameters
----------
x: array
input array.
axis: Optional[Union[int, Tuple[int, ...]]]
axis or axes along which to count non-zero values. By default, the number of non-zero values must be computed over the entire array. If a tuple of integers, the number of non-zero values must be computed over multiple axes. Default: ``None``.
keepdims: bool
if ``True``, the reduced axes (dimensions) must be included in the result as singleton dimensions, and, accordingly, the result must be compatible with the input array (see :ref:`broadcasting`). Otherwise, if ``False``, the reduced axes (dimensions) must not be included in the result. Default: ``False``.

.. note::
If ``x`` has a boolean data type, non-zero elements are those elements which are equal to ``True``.
Returns
-------
out: array
if the number of non-zeros values was computed over the entire array, a zero-dimensional array containing the total number of non-zero values; otherwise, a non-zero-dimensional array containing the counts along the specified axes. The returned array must have the default array index data type.

Notes
-----

- If ``x`` has a complex floating-point data type, non-zero elements are those elements having at least one component (real or imaginary) which is non-zero.
- If ``x`` has a boolean data type, non-zero elements are those elements which are equal to ``True``.
"""


def nonzero(x: array, /) -> Tuple[array, ...]:
"""
Returns the indices of the array elements which are non-zero.

.. admonition:: Data-dependent output shape
:class: admonition important
Expand All @@ -76,12 +102,15 @@ def nonzero(x: array, /) -> Tuple[array, ...]:

Returns
-------
out: Typle[array, ...]
out: Tuple[array, ...]
a tuple of ``k`` arrays, one for each dimension of ``x`` and each of size ``n`` (where ``n`` is the total number of non-zero elements), containing the indices of the non-zero elements in that dimension. The indices must be returned in row-major, C-style order. The returned array must have the default array index data type.

Notes
-----

- If ``x`` has a complex floating-point data type, non-zero elements are those elements having at least one component (real or imaginary) which is non-zero.
- If ``x`` has a boolean data type, non-zero elements are those elements which are equal to ``True``.

.. versionchanged:: 2022.12
Added complex data type support.
"""
Expand Down
Loading