Skip to content

Commit 0b51fa3

Browse files
committed
add isin tests
1 parent 267f8a9 commit 0b51fa3

File tree

1 file changed

+112
-0
lines changed

1 file changed

+112
-0
lines changed

dpnp/tests/test_logic.py

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -795,3 +795,115 @@ def test_array_equal_nan(a):
795795
result = dpnp.array_equal(dpnp.array(a), dpnp.array(b), equal_nan=True)
796796
expected = numpy.array_equal(a, b, equal_nan=True)
797797
assert_equal(result, expected)
798+
799+
800+
@pytest.mark.parametrize(
801+
"a",
802+
[
803+
numpy.array([1, 2, 3, 4]),
804+
numpy.array([[1, 2], [3, 4]]),
805+
],
806+
)
807+
@pytest.mark.parametrize(
808+
"b",
809+
[
810+
numpy.array([2, 4, 6]),
811+
numpy.array([[1, 3], [5, 7]]),
812+
],
813+
)
814+
def test_isin_basic(a, b):
815+
dp_a = dpnp.array(a)
816+
dp_b = dpnp.array(b)
817+
818+
expected = numpy.isin(a, b)
819+
result = dpnp.isin(dp_a, dp_b)
820+
assert_equal(result, expected)
821+
822+
823+
@pytest.mark.parametrize("dtype", get_all_dtypes())
824+
def test_isin_dtype(dtype):
825+
a = numpy.array([1, 2, 3, 4], dtype=dtype)
826+
b = numpy.array([2, 4], dtype=dtype)
827+
828+
dp_a = dpnp.array(a, dtype=dtype)
829+
dp_b = dpnp.array(b, dtype=dtype)
830+
831+
expected = numpy.isin(a, b)
832+
result = dpnp.isin(dp_a, dp_b)
833+
assert_equal(result, expected)
834+
835+
836+
@pytest.mark.parametrize("sh_a, sh_b", [((3, 1), (1, 4)), ((2, 3, 1), (1, 1))])
837+
def test_isin_broadcast(sh_a, sh_b):
838+
a = numpy.arange(numpy.prod(sh_a)).reshape(sh_a)
839+
b = numpy.arange(numpy.prod(sh_b)).reshape(sh_b)
840+
841+
dp_a = dpnp.array(a)
842+
dp_b = dpnp.array(b)
843+
844+
expected = numpy.isin(a, b)
845+
result = dpnp.isin(dp_a, dp_b)
846+
assert_equal(result, expected)
847+
848+
849+
def test_isin_scalar_elements():
850+
a = numpy.array([1, 2, 3])
851+
b = 2
852+
853+
dp_a = dpnp.array(a)
854+
dp_b = dpnp.array(b)
855+
856+
expected = numpy.isin(a, b)
857+
result = dpnp.isin(dp_a, dp_b)
858+
assert_equal(result, expected)
859+
860+
861+
def test_isin_scalar_test_elements():
862+
a = 2
863+
b = numpy.array([1, 2, 3])
864+
865+
dp_a = dpnp.array(a)
866+
dp_b = dpnp.array(b)
867+
868+
expected = numpy.isin(a, b)
869+
result = dpnp.isin(dp_a, dp_b)
870+
assert_equal(result, expected)
871+
872+
873+
def test_isin_empty():
874+
a = numpy.array([], dtype=int)
875+
b = numpy.array([1, 2, 3])
876+
877+
dp_a = dpnp.array(a)
878+
dp_b = dpnp.array(b)
879+
880+
expected = numpy.isin(a, b)
881+
result = dpnp.isin(dp_a, dp_b)
882+
assert_equal(result, expected)
883+
884+
885+
def test_isin_out_kwarg():
886+
a = numpy.array([1, 2, 3, 4])
887+
b = numpy.array([2, 4])
888+
889+
dp_a = dpnp.array(a)
890+
dp_b = dpnp.array(b)
891+
892+
expected = numpy.isin(a, b)
893+
out = dpnp.empty(expected.shape, dtype=dpnp.bool)
894+
result = dpnp.isin(dp_a, dp_b, out=out)
895+
896+
assert result is out
897+
assert_equal(result, expected)
898+
899+
900+
def test_isin_errors():
901+
a = dpnp.arange(5)
902+
b = dpnp.arange(3)
903+
904+
# unsupported type for elements or test_elements
905+
with pytest.raises(TypeError):
906+
dpnp.isin(dict(), b)
907+
908+
with pytest.raises(TypeError):
909+
dpnp.isin(a, dict())

0 commit comments

Comments
 (0)