@@ -154,5 +154,85 @@ def main(A: R.Tensor(["n", "n"], "float16")):
154154 tvm .testing .assert_allclose (tvm_output .numpy (), np_expected )
155155
156156
157+ @tvm .testing .parametrize_targets ("llvm" )
158+ def test_take_nan_mode_OOB_indices (target , dev , axis ):
159+ """Test R.take with mode="nan" and out-of-bounds indices.
160+ This test checks that out-of-bounds indices produce NaN values in the output tensor.
161+ """
162+
163+ @I .ir_module
164+ class Module :
165+ @R .function
166+ def main (A : R .Tensor ([3 , 3 ], "float16" )):
167+ output = R .take (A , R .const ([0 , 1 , 2 , 3 ]), axis = axis , mode = "nan" )
168+ return output
169+
170+ built = tvm .compile (Module , target = target )
171+ vm = tvm .relax .VirtualMachine (built , dev )
172+
173+ np_input = np .array ([[1.0 , 2.0 , 3.0 ], [4.0 , 5.0 , 6.0 ], [7.0 , 8.0 , 9.0 ]], dtype = "float16" )
174+ tvm_input = tvm .nd .array (np_input , dev )
175+ tvm_output = vm ["main" ](tvm_input )
176+ if axis == 0 :
177+ np_expected = np .array (
178+ [[1.0 , 2.0 , 3.0 ], [4.0 , 5.0 , 6.0 ], [7.0 , 8.0 , 9.0 ], [np .nan , np .nan , np .nan ]],
179+ dtype = "float16" ,
180+ )
181+ elif axis == 1 :
182+ np_expected = np .array (
183+ [[1.0 , 2.0 , 3.0 , np .nan ], [4.0 , 5.0 , 6.0 , np .nan ], [7.0 , 8.0 , 9.0 , np .nan ]],
184+ dtype = "float16" ,
185+ )
186+
187+ tvm .testing .assert_allclose (tvm_output .numpy (), np_expected )
188+
189+
190+ @tvm .testing .parametrize_targets ("llvm" )
191+ def test_take_wrap_mode_OOB_indices (target , dev , axis ):
192+ """Test R.take with mode="wrap" and out-of-bounds indices.
193+ This test checks that out-of-bounds indices wrap around to the valid range.
194+ """
195+
196+ @I .ir_module
197+ class Module :
198+ @R .function
199+ def main (A : R .Tensor ([3 , 3 ], "float16" )):
200+ output = R .take (A , R .const ([0 , 1 , 2 , 3 ]), axis = axis , mode = "wrap" )
201+ return output
202+
203+ built = tvm .compile (Module , target = target )
204+ vm = tvm .relax .VirtualMachine (built , dev )
205+
206+ np_input = np .random .random (size = [3 , 3 ]).astype ("float16" )
207+ tvm_input = tvm .nd .array (np_input , dev )
208+ tvm_output = vm ["main" ](tvm_input )
209+ np_expected = np .take (np_input , [0 , 1 , 2 , 3 ], axis = axis , mode = "wrap" )
210+
211+ tvm .testing .assert_allclose (tvm_output .numpy (), np_expected )
212+
213+
214+ @tvm .testing .parametrize_targets ("llvm" )
215+ def test_take_clip_mode_OOB_indices (target , dev , axis ):
216+ """Test R.take with mode="clip" and out-of-bounds indices.
217+ This test checks that out-of-bounds indices are clipped to the valid range.
218+ """
219+
220+ @I .ir_module
221+ class Module :
222+ @R .function
223+ def main (A : R .Tensor ([3 , 3 ], "float16" )):
224+ output = R .take (A , R .const ([0 , 1 , 2 , 3 ]), axis = axis , mode = "clip" )
225+ return output
226+
227+ built = tvm .compile (Module , target = target )
228+ vm = tvm .relax .VirtualMachine (built , dev )
229+ np_input = np .random .random (size = [3 , 3 ]).astype ("float16" )
230+ tvm_input = tvm .nd .array (np_input , dev )
231+ tvm_output = vm ["main" ](tvm_input )
232+ np_expected = np .take (np_input , [0 , 1 , 2 , 3 ], axis = axis , mode = "clip" )
233+
234+ tvm .testing .assert_allclose (tvm_output .numpy (), np_expected )
235+
236+
157237if __name__ == "__main__" :
158238 tvm .testing .main ()
0 commit comments