Skip to content

Commit f2bbbf3

Browse files
author
Vahid Tavanashad
committed
update dpnp.size to accept tuple of ints for axes
1 parent 9fc84dc commit f2bbbf3

File tree

2 files changed

+21
-11
lines changed

2 files changed

+21
-11
lines changed

dpnp/dpnp_iface_manipulation.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,11 @@
4646
import dpctl
4747
import dpctl.tensor as dpt
4848
import numpy
49-
from dpctl.tensor._numpy_helper import AxisError, normalize_axis_index
49+
from dpctl.tensor._numpy_helper import (
50+
AxisError,
51+
normalize_axis_index,
52+
normalize_axis_tuple,
53+
)
5054

5155
import dpnp
5256

@@ -3528,8 +3532,8 @@ def size(a, axis=None):
35283532
----------
35293533
a : array_like
35303534
Input data.
3531-
axis : {None, int}, optional
3532-
Axis along which the elements are counted.
3535+
axis : {None, int, tuple of ints}, optional
3536+
Axis or axes along which the elements are counted.
35333537
By default, give the total number of elements.
35343538
35353539
Default: ``None``.
@@ -3551,23 +3555,21 @@ def size(a, axis=None):
35513555
>>> a = [[1, 2, 3], [4, 5, 6]]
35523556
>>> np.size(a)
35533557
6
3554-
>>> np.size(a, 1)
3558+
>>> np.size(a, axis=1)
35553559
3
3556-
>>> np.size(a, 0)
3560+
>>> np.size(a, axis=0)
35573561
2
3558-
3559-
>>> a = np.asarray(a)
3560-
>>> np.size(a)
3562+
>>> np.size(a, axis=(0, 1))
35613563
6
3562-
>>> np.size(a, 1)
3563-
3
35643564
35653565
"""
35663566

35673567
if dpnp.is_supported_array_type(a):
35683568
if axis is None:
35693569
return a.size
3570-
return a.shape[axis]
3570+
_shape = a.shape
3571+
_axis = normalize_axis_tuple(axis, a.ndim)
3572+
return math.prod(_shape[ax] for ax in _axis)
35713573

35723574
return numpy.size(a, axis)
35733575

dpnp/tests/test_manipulation.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,8 @@ def test_ndim():
7474
assert dpnp.ndim(ia) == exp
7575

7676

77+
# TODO: include commented code in the test when numpy-2.4 is released
78+
# @testing.with_requires("numpy>=2.4")
7779
def test_size():
7880
a = [[1, 2, 3], [4, 5, 6]]
7981
ia = dpnp.array(a)
@@ -87,6 +89,12 @@ def test_size():
8789
assert dpnp.size(a, 0) == exp
8890
assert dpnp.size(ia, 0) == exp
8991

92+
assert dpnp.size(ia, 1) == numpy.size(a, 1)
93+
assert dpnp.size(ia, ()) == 1 # numpy.size(a, ())
94+
assert dpnp.size(ia, (0,)) == 2 # numpy.size(a, (0,))
95+
assert dpnp.size(ia, (1,)) == 3 # numpy.size(a, (1,))
96+
assert dpnp.size(ia, (0, 1)) == 6 # numpy.size(a, (0, 1))
97+
9098

9199
class TestAppend:
92100
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)