2121
2222import numpy as np
2323import pytest
24-
2524import tvm
2625import tvm .testing
27-
2826from tvm import relay , te
2927from tvm .error import TVMError
3028from tvm .relay import create_executor , transform
3129from tvm .relay .testing import check_grad , run_infer_type
3230
3331from utils import ref_funcs
3432
35-
3633executor_kind = tvm .testing .parameter ("graph" , "debug" )
3734
3835
@@ -426,31 +423,36 @@ def test_take(self, dshape, indices_shape, oshape, axis):
426423
427424
428425class TestTake :
429- src_shape , indices_src , axis , mode = tvm .testing .parameters (
430- ((4 ,), [1 ], None , "clip" ),
431- ((4 ,), [[0 , 1 , 2 , 3 ]], None , "clip" ),
432- ((3 , 3 , 3 ), [[11 , 25 ]], None , "clip" ),
433- ((4 ,), [[0 , 1 ], [2 , 3 ]], None , "clip" ),
434- ((4 ,), [1 ], 0 , "clip" ),
435- ((2 , 2 ), [[[1 , 0 ], [0 , 1 ]]], 0 , "clip" ),
436- ((2 , 2 ), [[[1 , 0 ], [0 , 1 ]]], 1 , "clip" ),
437- ((4 , 3 , 5 , 6 ), [[2 , 1 , 0 , 0 ]], - 2 , "clip" ),
438- ((3 , 4 ), [- 5 , 20 ], None , "clip" ),
439- ((3 , 4 ), [- 5 , 20 ], None , "wrap" ),
440- ((3 , 4 ), [- 1 , 2 ], 0 , "clip" ),
441- ((3 , 4 ), [- 1 , 2 ], 0 , "wrap" ),
442- ((3 , 4 ), [- 1 , 2 ], 1 , "clip" ),
443- ((3 , 4 ), [- 1 , 2 ], 1 , "wrap" ),
444- ((3 , 3 , 3 ), [[11 , 25 ]], None , "fast" ),
445- ((3 , 4 ), [0 , 2 ], 0 , "fast" ),
446- ((3 , 4 ), [0 , 2 ], 1 , "fast" ),
426+ src_shape , indices_src , axis , mode , indices_dtype = tvm .testing .parameters (
427+ ((4 ,), [1 ], None , "clip" , "int32" ),
428+ ((4 ,), [[0 , 1 , 2 , 3 ]], None , "clip" , "int32" ),
429+ ((3 , 3 , 3 ), [[11 , 25 ]], None , "clip" , "int32" ),
430+ ((4 ,), [[0 , 1 ], [2 , 3 ]], None , "clip" , "int32" ),
431+ ((4 ,), [1 ], 0 , "clip" , "int32" ),
432+ ((2 , 2 ), [[[1 , 0 ], [0 , 1 ]]], 0 , "clip" , "int32" ),
433+ ((2 , 2 ), [[[1 , 0 ], [0 , 1 ]]], 1 , "clip" , "int32" ),
434+ ((4 , 3 , 5 , 6 ), [[2 , 1 , 0 , 0 ]], - 2 , "clip" , "int32" ),
435+ ((3 , 4 ), [- 5 , 20 ], None , "clip" , "int32" ),
436+ ((3 , 4 ), [- 5 , 20 ], None , "wrap" , "int32" ),
437+ ((3 , 4 ), [- 1 , 2 ], 0 , "clip" , "int32" ),
438+ ((3 , 4 ), [- 1 , 2 ], 0 , "wrap" , "int32" ),
439+ ((3 , 4 ), [- 1 , 2 ], 1 , "clip" , "int32" ),
440+ ((3 , 4 ), [- 1 , 2 ], 1 , "wrap" , "int32" ),
441+ ((3 , 3 , 3 ), [[11 , 25 ]], None , "fast" , "int32" ),
442+ ((3 , 4 ), [0 , 2 ], 0 , "fast" , "int32" ),
443+ ((3 , 4 ), [0 , 2 ], 1 , "fast" , "int32" ),
444+ ((3 , 4 ), [1 , 2 ], 1 , "clip" , "uint32" ),
445+ ((3 , 4 ), [1 , 2 ], 1 , "wrap" , "uint16" ),
446+ ((3 , 3 , 3 ), [1 , 2 ], None , "fast" , "uint16" ),
447+ ((3 , 4 ), [0 , 2 ], 0 , "fast" , "uint8" ),
447448 )
448449
449450 # Incorrect numeric output in some cases on vulkan
450451 @tvm .testing .known_failing_targets ("vulkan" )
451- def test_take (self , target , dev , executor_kind , src_shape , indices_src , axis , mode ):
452+ def test_take (
453+ self , target , dev , executor_kind , src_shape , indices_src , axis , mode , indices_dtype
454+ ):
452455 src_dtype = "float32"
453- indices_dtype = "int32"
454456 indices_src = np .array (indices_src , dtype = indices_dtype )
455457 x = relay .var ("x" , relay .TensorType (src_shape , src_dtype ))
456458 indices = relay .var ("indices" , relay .TensorType (indices_src .shape , indices_dtype ))
@@ -459,11 +461,16 @@ def test_take(self, target, dev, executor_kind, src_shape, indices_src, axis, mo
459461 func = relay .Function ([x , indices ], z )
460462 x_data = np .random .uniform (low = - 1 , high = 1 , size = src_shape ).astype (src_dtype )
461463 np_mode = "raise" if mode == "fast" else mode
462- ref_res = np .take (x_data , indices = indices_src , axis = axis , mode = np_mode )
463464
464465 op_res = relay .create_executor (executor_kind , device = dev , target = target ).evaluate (func )(
465466 x_data , indices_src
466467 )
468+
469+ # Old versions of numpy has take internally cast inside take which may violate
470+ # safety rules. We have such version in i386 CI image.
471+ indices_src = indices_src .astype ("int32" )
472+ ref_res = np .take (x_data , indices = indices_src , axis = axis , mode = np_mode )
473+
467474 tvm .testing .assert_allclose (op_res .numpy (), ref_res , rtol = 1e-5 )
468475
469476
@@ -1267,12 +1274,12 @@ def test_scatter_add(self, target, dev, ref_data, dshape, ishape, axis, dtype):
12671274 ],
12681275)
12691276def test_gather (target , dev , executor_kind , data , axis , indices , ref_res ):
1270- def verify_gather (data , axis , indices , ref_res ):
1277+ def verify_gather (data , axis , indices , ref_res , indices_dtype = "int32" ):
12711278 data = np .asarray (data , dtype = "float32" )
1272- indices = np .asarray (indices , dtype = "int32" )
1279+ indices = np .asarray (indices , dtype = indices_dtype )
12731280 ref_res = np .asarray (ref_res )
12741281 d = relay .var ("x" , relay .TensorType (data .shape , "float32" ))
1275- i = relay .var ("y" , relay .TensorType (indices .shape , "int32" ))
1282+ i = relay .var ("y" , relay .TensorType (indices .shape , indices_dtype ))
12761283 z = relay .gather (d , axis , i )
12771284
12781285 func = relay .Function ([d , i ], z )
@@ -1283,22 +1290,25 @@ def verify_gather(data, axis, indices, ref_res):
12831290 tvm .testing .assert_allclose (op_res .numpy (), ref_res , rtol = 1e-5 )
12841291
12851292 verify_gather (data , axis , indices , ref_res )
1293+ verify_gather (data , axis , indices , ref_res , indices_dtype = "uint32" )
1294+
1295+ verify_gather (data , axis , indices , ref_res )
12861296
12871297
12881298def test_gather_nd (target , dev , executor_kind ):
1289- def verify_gather_nd (xshape , yshape , y_data , batch_dims = 0 ):
1299+ def verify_gather_nd (xshape , yshape , y_data , batch_dims = 0 , indices_dtype = "int32" ):
12901300 x = relay .var ("x" , relay .TensorType (xshape , "float32" ))
1291- y = relay .var ("y" , relay .TensorType (yshape , "int32" ))
1301+ y = relay .var ("y" , relay .TensorType (yshape , indices_dtype ))
12921302 z = relay .gather_nd (x , y , batch_dims )
12931303
12941304 func = relay .Function ([x , y ], z )
12951305
12961306 x_data = np .random .uniform (size = xshape ).astype ("float32" )
12971307
12981308 if y_data :
1299- y_data = np .array (y_data , dtype = "int32" )
1309+ y_data = np .array (y_data , dtype = indices_dtype )
13001310 else :
1301- y_data = np .random .randint (low = 0 , high = 2 , size = yshape , dtype = "int32" )
1311+ y_data = np .random .randint (low = 0 , high = 2 , size = yshape , dtype = indices_dtype )
13021312
13031313 ref_res = ref_funcs .gather_nd (x_data , y_data , batch_dims )
13041314
@@ -1335,6 +1345,9 @@ def verify_gather_nd(xshape, yshape, y_data, batch_dims=0):
13351345 verify_gather_nd ((3 , 2 , 2 , 3 , 4 ), (2 , 3 , 2 , 2 ), None , 2 )
13361346 verify_gather_nd ((3 , 2 , 2 , 3 , 4 ), (1 , 3 , 2 , 3 ), None , 2 )
13371347
1348+ verify_gather_nd ((3 , 2 , 2 , 3 , 4 ), (1 , 3 , 2 , 3 ), None , 2 , indices_dtype = "uint8" )
1349+ verify_gather_nd ((2 , 2 , 2 ), (2 , 2 , 1 ), [[[1 ], [0 ]], [[0 ], [1 ]]], 1 , indices_dtype = "uint32" )
1350+
13381351
13391352def _verify_infiniteness_ops (relay_op , ref_op ):
13401353 for dtype in ["float32" , "float16" , "float16" , "int32" , "int16" ]:
0 commit comments