Skip to content

Commit ca02227

Browse files
authored
Merge 77d5512 into 131c490
2 parents 131c490 + 77d5512 commit ca02227

File tree

5 files changed

+211
-1
lines changed

5 files changed

+211
-1
lines changed

dpnp/dpnp_iface_logic.py

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@
6969
"iscomplexobj",
7070
"isfinite",
7171
"isfortran",
72+
"isin",
7273
"isinf",
7374
"isnan",
7475
"isneginf",
@@ -1196,6 +1197,120 @@ def isfortran(a):
11961197
return a.flags.fnc
11971198

11981199

1200+
def isin(
1201+
element,
1202+
test_elements,
1203+
assume_unique=False, # pylint: disable=unused-argument
1204+
invert=False,
1205+
*,
1206+
kind=None, # pylint: disable=unused-argument
1207+
):
1208+
"""
1209+
Calculates ``element in test_elements``, broadcasting over `element` only.
1210+
Returns a boolean array of the same shape as `element` that is ``True``
1211+
where an element of `element` is in `test_elements` and ``False``
1212+
otherwise.
1213+
1214+
For full documentation refer to :obj:`numpy.isin`.
1215+
1216+
Parameters
1217+
----------
1218+
element : {dpnp.ndarray, usm_ndarray, scalar}
1219+
Input array.
1220+
test_elements : {dpnp.ndarray, usm_ndarray, scalar}
1221+
The values against which to test each value of `element`.
1222+
This argument is flattened if it is an array.
1223+
assume_unique : bool, optional
1224+
Ignored, as no performance benefit is gained by assuming the
1225+
input arrays are unique. Included for compatibility with NumPy.
1226+
1227+
Default: ``False``.
1228+
invert : bool, optional
1229+
If ``True``, the values in the returned array are inverted, as if
1230+
calculating ``element not in test_elements``.
1231+
``dpnp.isin(a, b, invert=True)`` is equivalent to (but faster
1232+
than) ``dpnp.invert(dpnp.isin(a, b))``.
1233+
1234+
Default: ``False``.
1235+
kind : {None, "sort"}, optional
1236+
Ignored, as the only algorithm implemented is ``"sort"``. Included for
1237+
compatibility with NumPy.
1238+
1239+
Default: ``None``.
1240+
1241+
Returns
1242+
-------
1243+
isin : dpnp.ndarray of bool dtype
1244+
Has the same shape as `element`. The values `element[isin]`
1245+
are in `test_elements`.
1246+
1247+
Examples
1248+
--------
1249+
>>> import dpnp as np
1250+
>>> element = 2*np.arange(4).reshape((2, 2))
1251+
>>> element
1252+
array([[0, 2],
1253+
[4, 6]])
1254+
>>> test_elements = [1, 2, 4, 8]
1255+
>>> mask = np.isin(element, test_elements)
1256+
>>> mask
1257+
array([[False, True],
1258+
[ True, False]])
1259+
>>> element[mask]
1260+
array([2, 4])
1261+
1262+
The indices of the matched values can be obtained with `nonzero`:
1263+
1264+
>>> np.nonzero(mask)
1265+
(array([0, 1]), array([1, 0]))
1266+
1267+
The test can also be inverted:
1268+
1269+
>>> mask = np.isin(element, test_elements, invert=True)
1270+
>>> mask
1271+
array([[ True, False],
1272+
[False, True]])
1273+
>>> element[mask]
1274+
array([0, 6])
1275+
1276+
"""
1277+
1278+
dpnp.check_supported_arrays_type(element, test_elements, scalar_type=True)
1279+
if dpnp.isscalar(element):
1280+
usm_element = dpnp.as_usm_ndarray(
1281+
element,
1282+
usm_type=test_elements.usm_type,
1283+
sycl_queue=test_elements.sycl_queue,
1284+
)
1285+
usm_test = dpnp.get_usm_ndarray(test_elements)
1286+
elif dpnp.isscalar(test_elements):
1287+
usm_test = dpnp.as_usm_ndarray(
1288+
test_elements,
1289+
usm_type=element.usm_type,
1290+
sycl_queue=element.sycl_queue,
1291+
)
1292+
usm_element = dpnp.get_usm_ndarray(element)
1293+
else:
1294+
if (
1295+
dpu.get_execution_queue(
1296+
(element.sycl_queue, test_elements.sycl_queue)
1297+
)
1298+
is None
1299+
):
1300+
raise dpu.ExecutionPlacementError(
1301+
"Input arrays have incompatible allocation queues"
1302+
)
1303+
usm_element = dpnp.get_usm_ndarray(element)
1304+
usm_test = dpnp.get_usm_ndarray(test_elements)
1305+
return dpnp.get_result_array(
1306+
dpt.isin(
1307+
usm_element,
1308+
usm_test,
1309+
invert=invert,
1310+
)
1311+
)
1312+
1313+
11991314
_ISINF_DOCSTRING = """
12001315
Tests each element :math:`x_i` of the input array `x` to determine if equal to
12011316
positive or negative infinity.

dpnp/tests/test_logic.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -795,3 +795,97 @@ def test_array_equal_nan(a):
795795
result = dpnp.array_equal(dpnp.array(a), dpnp.array(b), equal_nan=True)
796796
expected = numpy.array_equal(a, b, equal_nan=True)
797797
assert_equal(result, expected)
798+
799+
800+
class TestIsin:
801+
@pytest.mark.parametrize(
802+
"a",
803+
[
804+
numpy.array([1, 2, 3, 4]),
805+
numpy.array([[1, 2], [3, 4]]),
806+
],
807+
)
808+
@pytest.mark.parametrize(
809+
"b",
810+
[
811+
numpy.array([2, 4, 6]),
812+
numpy.array([[1, 3], [5, 7]]),
813+
],
814+
)
815+
def test_isin_basic(a, b):
816+
dp_a = dpnp.array(a)
817+
dp_b = dpnp.array(b)
818+
819+
expected = numpy.isin(a, b)
820+
result = dpnp.isin(dp_a, dp_b)
821+
assert_equal(result, expected)
822+
823+
@pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True))
824+
def test_isin_dtype(dtype):
825+
a = numpy.array([1, 2, 3, 4], dtype=dtype)
826+
b = numpy.array([2, 4], dtype=dtype)
827+
828+
dp_a = dpnp.array(a, dtype=dtype)
829+
dp_b = dpnp.array(b, dtype=dtype)
830+
831+
expected = numpy.isin(a, b)
832+
result = dpnp.isin(dp_a, dp_b)
833+
assert_equal(result, expected)
834+
835+
@pytest.mark.parametrize(
836+
"sh_a, sh_b", [((3, 1), (1, 4)), ((2, 3, 1), (1, 1))]
837+
)
838+
def test_isin_broadcast(sh_a, sh_b):
839+
a = numpy.arange(numpy.prod(sh_a)).reshape(sh_a)
840+
b = numpy.arange(numpy.prod(sh_b)).reshape(sh_b)
841+
842+
dp_a = dpnp.array(a)
843+
dp_b = dpnp.array(b)
844+
845+
expected = numpy.isin(a, b)
846+
result = dpnp.isin(dp_a, dp_b)
847+
assert_equal(result, expected)
848+
849+
def test_isin_scalar_elements():
850+
a = numpy.array([1, 2, 3])
851+
b = 2
852+
853+
dp_a = dpnp.array(a)
854+
dp_b = dpnp.array(b)
855+
856+
expected = numpy.isin(a, b)
857+
result = dpnp.isin(dp_a, dp_b)
858+
assert_equal(result, expected)
859+
860+
def test_isin_scalar_test_elements():
861+
a = 2
862+
b = numpy.array([1, 2, 3])
863+
864+
dp_a = dpnp.array(a)
865+
dp_b = dpnp.array(b)
866+
867+
expected = numpy.isin(a, b)
868+
result = dpnp.isin(dp_a, dp_b)
869+
assert_equal(result, expected)
870+
871+
def test_isin_empty():
872+
a = numpy.array([], dtype=int)
873+
b = numpy.array([1, 2, 3])
874+
875+
dp_a = dpnp.array(a)
876+
dp_b = dpnp.array(b)
877+
878+
expected = numpy.isin(a, b)
879+
result = dpnp.isin(dp_a, dp_b)
880+
assert_equal(result, expected)
881+
882+
def test_isin_errors():
883+
a = dpnp.arange(5)
884+
b = dpnp.arange(3)
885+
886+
# unsupported type for elements or test_elements
887+
with pytest.raises(TypeError):
888+
dpnp.isin(dict(), b)
889+
890+
with pytest.raises(TypeError):
891+
dpnp.isin(a, dict())

dpnp/tests/test_sycl_queue.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -536,6 +536,7 @@ def test_logic_op_1in(op, device):
536536
"greater",
537537
"greater_equal",
538538
"isclose",
539+
"isin",
539540
"less",
540541
"less_equal",
541542
"logical_and",

dpnp/tests/test_usm_type.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -355,6 +355,7 @@ def test_logic_op_1in(op, usm_type_x):
355355
"greater",
356356
"greater_equal",
357357
"isclose",
358+
"isin",
358359
"less",
359360
"less_equal",
360361
"logical_and",

dpnp/tests/third_party/cupy/logic_tests/test_truth.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,6 @@ def test_with_out(self, xp, dtype):
8989
return out
9090

9191

92-
@pytest.mark.skip("isin() is not supported yet")
9392
@testing.parameterize(
9493
*testing.product(
9594
{

0 commit comments

Comments
 (0)