Skip to content

Commit 7883136

Browse files
vacu9708 (Youngsik Yang)
authored and committed
[Relax][Transform] Add mode choice, NaN mode, and warning for take()
- Add a `mode` parameter to Relax’s `take()`
- Add a `NaN` mode to `take()`
- Add unit tests covering all `take()` modes
- Add a warning log for `fast` mode
- Unify default modes in lower layers to `fast` for consistency with Relax
1 parent 17113f8 commit 7883136

File tree

8 files changed

+147
-20
lines changed

8 files changed

+147
-20
lines changed

include/tvm/relax/attrs/index.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,14 @@ namespace relax {
3232
/*! \brief Attributes used in take operator */
3333
struct TakeAttrs : public AttrsNodeReflAdapter<TakeAttrs> {
3434
Optional<int64_t> axis;
35+
String mode;
3536

3637
static void RegisterReflection() {
3738
namespace refl = tvm::ffi::reflection;
38-
refl::ObjectDef<TakeAttrs>().def_ro("axis", &TakeAttrs::axis,
39-
"The axis over which to select values.");
39+
refl::ObjectDef<TakeAttrs>()
40+
.def_ro("axis", &TakeAttrs::axis, "The axis over which to select values.")
41+
.def_ro("mode", &TakeAttrs::mode, "The mode for handling out-of-bounds indices.",
42+
refl::DefaultValue("fast"));
4043
}
4144

4245
static constexpr const char* _type_key = "relax.attrs.TakeAttrs";

include/tvm/topi/transform.h

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1008,7 +1008,7 @@ inline Array<Tensor> split_n_sections(const Tensor& x, int num_sections, int axi
10081008
* \return A Tensor whose op member is the take operation
10091009
*/
10101010
inline Tensor take(const Tensor& a, const Tensor& indices, int batch_dims,
1011-
std::string mode = "clip", std::string name = "T_take",
1011+
std::string mode = "fast", std::string name = "T_take",
10121012
std::string tag = kInjective) {
10131013
Array<PrimExpr> a_shape = a->shape;
10141014
Array<PrimExpr> out_shape = indices->shape;
@@ -1032,6 +1032,16 @@ inline Tensor take(const Tensor& a, const Tensor& indices, int batch_dims,
10321032
out_shape,
10331033
[&](const Array<Var>& out_index) { return a(UnravelIndex(indices(out_index), a_shape)); },
10341034
name, tag);
1035+
} else if (mode == "nan") {
1036+
return compute(
1037+
out_shape,
1038+
[&](const Array<Var>& out_index) {
1039+
auto idx = tvm::if_then_else(
1040+
indices(out_index) < 0 || indices(out_index) >= a_size,
1041+
tvm::FloatImm(a->dtype, std::numeric_limits<float>::quiet_NaN()), indices(out_index));
1042+
return a(UnravelIndex(idx, a_shape));
1043+
},
1044+
name, tag);
10351045
} else { // mode == "wrap"
10361046
return compute(
10371047
out_shape,
@@ -1094,7 +1104,7 @@ inline Tensor sequence_mask(const Tensor& data, const Tensor& valid_length, doub
10941104
* \return A Tensor whose op member is the take operation
10951105
*/
10961106
inline Tensor take(const Tensor& a, Variant<Tensor, PrimExpr> indices, int batch_dims, int axis,
1097-
std::string mode = "clip", std::string name = "T_take",
1107+
std::string mode = "fast", std::string name = "T_take",
10981108
std::string tag = kInjective) {
10991109
if (axis < 0) {
11001110
axis += static_cast<int>(a->shape.size());
@@ -1206,6 +1216,8 @@ inline Tensor take(const Tensor& a, Variant<Tensor, PrimExpr> indices, int batch
12061216
name, tag);
12071217
}
12081218
} else if (mode == "fast") {
1219+
LOG(WARNING) << "Fast mode segfaults when there are out-of-bounds indices. "
1220+
"Make sure input indices are in bound";
12091221
return compute(
12101222
out_shape,
12111223
[&](const Array<Var>& out_index) {
@@ -1224,6 +1236,29 @@ inline Tensor take(const Tensor& a, Variant<Tensor, PrimExpr> indices, int batch
12241236
return a(real_indices);
12251237
},
12261238
name, tag);
1239+
} else if (mode == "nan") {
1240+
return compute(
1241+
out_shape,
1242+
[&](const Array<Var>& out_index) {
1243+
Array<PrimExpr> indices_position;
1244+
for (size_t j = axis; j < static_cast<size_t>(axis + indices_len); ++j) {
1245+
indices_position.push_back(out_index[j]);
1246+
}
1247+
Array<PrimExpr> real_indices;
1248+
for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
1249+
real_indices.push_back(out_index[j]);
1250+
}
1251+
PrimExpr idx = get_index(indices_position);
1252+
real_indices.push_back(idx);
1253+
for (size_t j = axis + indices_len; j < out_index.size(); ++j) {
1254+
real_indices.push_back(out_index[j]);
1255+
}
1256+
PrimExpr in_bounds = idx >= 0 && idx < axis_dim;
1257+
return tvm::if_then_else(
1258+
in_bounds, a(real_indices),
1259+
tvm::tir::make_const(a->dtype, std::numeric_limits<float>::quiet_NaN()));
1260+
},
1261+
name, tag);
12271262
} else { // mode == "wrap"
12281263
return compute(
12291264
out_shape,

python/tvm/relax/op/index.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
PrimExprLike = Union[int, PrimExpr]
2727

2828

29-
def take(x: Expr, indices: Expr, axis: Optional[int] = None) -> Expr:
29+
def take(x: Expr, indices: Expr, axis: Optional[int] = None, mode: str = "fast") -> Expr:
3030
"""Take elements from a tensor along an axis.
3131
Its semantic is mostly similar to `numpy.take`
3232
(https://numpy.org/doc/stable/reference/generated/numpy.take.html),
@@ -45,12 +45,18 @@ def take(x: Expr, indices: Expr, axis: Optional[int] = None) -> Expr:
4545
The axis over which to select values.
4646
If it is none, the input tensor is required to be one-dimensional.
4747
48+
mode : str
49+
Specifies how out-of-bounds indices will behave.
50+
- fast (default): extra indices lead to seg fault (user must make sure indices are in-bound)
51+
- nan: produce NaNs for out-of-bounds indices
52+
- wrap: wrap around the indices
53+
- clip: clip to the range
4854
Returns
4955
-------
5056
ret : relax.Expr
5157
The taken result.
5258
"""
53-
return _ffi_api.take(x, indices, axis) # type: ignore
59+
return _ffi_api.take(x, indices, axis, mode) # type: ignore
5460

5561

5662
@args_converter.auto

python/tvm/relax/transform/legalize_ops/index.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,9 @@
2626

2727
@register_legalize("relax.take")
2828
def _take(bb: BlockBuilder, call: Call) -> Expr:
29-
# Currently Relax `take` operator doesn't provide the mode choices and
30-
# requires input indices to be in range.
31-
# We use fast mode, which leads to runtime error whenever some index is
32-
# out of bound.
33-
return bb.call_te(topi.take, call.args[0], call.args[1], call.attrs.axis, mode="fast")
29+
# Currently "fast" is the default mode, which leads to segmentation faults
30+
# when there are out-of-bounds indices.
31+
return bb.call_te(topi.take, call.args[0], call.args[1], call.attrs.axis, mode=call.attrs.mode)
3432

3533

3634
@register_legalize("relax.strided_slice")

python/tvm/topi/transform.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -446,7 +446,7 @@ def split(ary, indices_or_sections, axis=0):
446446
return cpp.split(ary, indices_or_sections, axis)
447447

448448

449-
def take(a, indices, axis=None, batch_dims=0, mode="clip"):
449+
def take(a, indices, axis=None, batch_dims=0, mode="fast"):
450450
"""Take elements from an array along an axis.
451451
452452
Parameters
@@ -465,10 +465,11 @@ def take(a, indices, axis=None, batch_dims=0, mode="clip"):
465465
The number of batch dimensions. By default is 0.
466466
467467
mode : str, optional
468-
Specifies how out-of-bound indices will behave.
469-
clip - clip to the range (default)
470-
wrap - wrap around the indices
471-
fast - no clip or wrap around (user must make sure indices are in-bound)
468+
Specifies how out-of-bounds indices will behave.
469+
- fast (default): extra indices lead to seg fault (user must make sure indices are in-bound)
470+
- nan: produce NaNs for out-of-bounds indices
471+
- wrap: wrap around the indices
472+
- clip: clip to the range
472473
473474
Returns
474475
-------

src/relax/op/tensor/index.cc

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,10 @@ TVM_FFI_STATIC_INIT_BLOCK({
4444
/* relax.take */
4545
TVM_REGISTER_NODE_TYPE(TakeAttrs);
4646

47-
Expr take(Expr x, Expr indices, Optional<int64_t> axis) {
47+
Expr take(Expr x, Expr indices, Optional<int64_t> axis, String mode) {
4848
ObjectPtr<TakeAttrs> attrs = make_object<TakeAttrs>();
4949
attrs->axis = std::move(axis);
50+
attrs->mode = std::move(mode);
5051

5152
static const Op& op = Op::Get("relax.take");
5253
return Call(op, {std::move(x), std::move(indices)}, Attrs(attrs), {});
@@ -100,8 +101,10 @@ StructInfo InferStructInfoTake(const Call& call, const BlockBuilder& ctx) {
100101
return TensorStructInfo(data_sinfo->dtype, kUnknownNDim, data_sinfo->vdevice);
101102
}
102103

103-
int axis =
104-
attrs->axis.has_value() ? NormalizeAxis(call, ctx, data_sinfo->ndim, attrs->axis.value()) : 0;
104+
int axis = 0;
105+
if (attrs->axis.has_value()) {
106+
axis = NormalizeAxis(call, ctx, data_sinfo->ndim, attrs->axis.value());
107+
}
105108
const auto* data_shape = data_sinfo->shape.as<ShapeExprNode>();
106109
const auto* indices_shape = indices_sinfo->shape.as<ShapeExprNode>();
107110
if (data_shape == nullptr || indices_shape == nullptr) {

src/relax/op/tensor/index.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,10 @@ namespace relax {
3838
* It is required to be a one-dimensional tensor which has integer dtype.
3939
* \param axis The axis over which to select values.
4040
* If it is `std::nullopt`, the input tensor is required to be one-dimensional.
41+
* \param mode The mode for handling out-of-bounds indices.
4142
* \return The taken result.
4243
*/
43-
Expr take(Expr x, Expr indices, Optional<int64_t> axis);
44+
Expr take(Expr x, Expr indices, Optional<int64_t> axis, String mode = "fast");
4445

4546
/*!
4647
* \brief Strided slice of a tensor.

tests/python/relax/test_op_take.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,5 +154,85 @@ def main(A: R.Tensor(["n", "n"], "float16")):
154154
tvm.testing.assert_allclose(tvm_output.numpy(), np_expected)
155155

156156

157+
@tvm.testing.parametrize_targets("llvm")
158+
def test_take_nan_mode_OOB_indices(target, dev, axis):
159+
"""Test R.take with mode="nan" and out-of-bounds indices.
160+
This test checks that out-of-bounds indices produce NaN values in the output tensor.
161+
"""
162+
163+
@I.ir_module
164+
class Module:
165+
@R.function
166+
def main(A: R.Tensor([3, 3], "float16")):
167+
output = R.take(A, R.const([0, 1, 2, 3]), axis=axis, mode="nan")
168+
return output
169+
170+
built = tvm.compile(Module, target=target)
171+
vm = tvm.relax.VirtualMachine(built, dev)
172+
173+
np_input = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], dtype="float16")
174+
tvm_input = tvm.nd.array(np_input, dev)
175+
tvm_output = vm["main"](tvm_input)
176+
if axis == 0:
177+
np_expected = np.array(
178+
[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0], [np.nan, np.nan, np.nan]],
179+
dtype="float16",
180+
)
181+
elif axis == 1:
182+
np_expected = np.array(
183+
[[1.0, 2.0, 3.0, np.nan], [4.0, 5.0, 6.0, np.nan], [7.0, 8.0, 9.0, np.nan]],
184+
dtype="float16",
185+
)
186+
187+
tvm.testing.assert_allclose(tvm_output.numpy(), np_expected)
188+
189+
190+
@tvm.testing.parametrize_targets("llvm")
191+
def test_take_wrap_mode_OOB_indices(target, dev, axis):
192+
"""Test R.take with mode="wrap" and out-of-bounds indices.
193+
This test checks that out-of-bounds indices wrap around to the valid range.
194+
"""
195+
196+
@I.ir_module
197+
class Module:
198+
@R.function
199+
def main(A: R.Tensor([3, 3], "float16")):
200+
output = R.take(A, R.const([0, 1, 2, 3]), axis=axis, mode="wrap")
201+
return output
202+
203+
built = tvm.compile(Module, target=target)
204+
vm = tvm.relax.VirtualMachine(built, dev)
205+
206+
np_input = np.random.random(size=[3, 3]).astype("float16")
207+
tvm_input = tvm.nd.array(np_input, dev)
208+
tvm_output = vm["main"](tvm_input)
209+
np_expected = np.take(np_input, [0, 1, 2, 3], axis=axis, mode="wrap")
210+
211+
tvm.testing.assert_allclose(tvm_output.numpy(), np_expected)
212+
213+
214+
@tvm.testing.parametrize_targets("llvm")
215+
def test_take_clip_mode_OOB_indices(target, dev, axis):
216+
"""Test R.take with mode="clip" and out-of-bounds indices.
217+
This test checks that out-of-bounds indices are clipped to the valid range.
218+
"""
219+
220+
@I.ir_module
221+
class Module:
222+
@R.function
223+
def main(A: R.Tensor([3, 3], "float16")):
224+
output = R.take(A, R.const([0, 1, 2, 3]), axis=axis, mode="clip")
225+
return output
226+
227+
built = tvm.compile(Module, target=target)
228+
vm = tvm.relax.VirtualMachine(built, dev)
229+
np_input = np.random.random(size=[3, 3]).astype("float16")
230+
tvm_input = tvm.nd.array(np_input, dev)
231+
tvm_output = vm["main"](tvm_input)
232+
np_expected = np.take(np_input, [0, 1, 2, 3], axis=axis, mode="clip")
233+
234+
tvm.testing.assert_allclose(tvm_output.numpy(), np_expected)
235+
236+
157237
if __name__ == "__main__":
158238
tvm.testing.main()

0 commit comments

Comments
 (0)