Skip to content

Commit fda34d2

Browse files
vacu9708Youngsik Yang
authored andcommitted
[Relax][Transform] Add mode choice, NaN mode, and warning for take()
- Add a `mode` parameter to Relax’s `take()` - Add `NaN` mode to `take()` - Add unit tests covering all `take()` modes - Add a warning log for `fast` mode - Unify default modes in lower layers to `fast` for consistency with Relax
1 parent 17113f8 commit fda34d2

File tree

3 files changed

+17
-5
lines changed

3 files changed

+17
-5
lines changed

include/tvm/relax/attrs/index.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,14 @@ namespace relax {
3232
/*! \brief Attributes used in take operator */
3333
struct TakeAttrs : public AttrsNodeReflAdapter<TakeAttrs> {
3434
Optional<int64_t> axis;
35+
String mode;
3536

3637
static void RegisterReflection() {
3738
namespace refl = tvm::ffi::reflection;
38-
refl::ObjectDef<TakeAttrs>().def_ro("axis", &TakeAttrs::axis,
39-
"The axis over which to select values.");
39+
refl::ObjectDef<TakeAttrs>()
40+
.def_ro("axis", &TakeAttrs::axis, "The axis over which to select values.")
41+
.def_ro("mode", &TakeAttrs::mode, "The mode for handling out-of-bounds indices.",
42+
refl::DefaultValue("fast"));
4043
}
4144

4245
static constexpr const char* _type_key = "relax.attrs.TakeAttrs";

python/tvm/relax/op/index.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
PrimExprLike = Union[int, PrimExpr]
2727

2828

29-
def take(x: Expr, indices: Expr, axis: Optional[int] = None) -> Expr:
29+
def take(x: Expr, indices: Expr, axis: Optional[int] = None, mode: str = "fast") -> Expr:
3030
"""Take elements from a tensor along an axis.
3131
Its semantic is mostly similar to `numpy.take`
3232
(https://numpy.org/doc/stable/reference/generated/numpy.take.html),
@@ -45,12 +45,20 @@ def take(x: Expr, indices: Expr, axis: Optional[int] = None) -> Expr:
4545
The axis over which to select values.
4646
If it is none, the input tensor is required to be one-dimensional.
4747
48+
mode : str
49+
Specifies how out-of-bounds indices will behave.
50+
- fast (default): extra indices lead to seg fault (user must make sure indices are in-bound)
51+
- nan: produce NaNs for out-of-bounds indices
52+
- wrap: wrap around the indices
53+
- clip: clip to the range
54+
'clip' mode means that all indices that are too large are replaced
55+
by the index that addresses the last element along that axis.
4856
Returns
4957
-------
5058
ret : relax.Expr
5159
The taken result.
5260
"""
53-
return _ffi_api.take(x, indices, axis) # type: ignore
61+
return _ffi_api.take(x, indices, axis, mode) # type: ignore
5462

5563

5664
@args_converter.auto

src/relax/op/tensor/index.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,10 @@ TVM_FFI_STATIC_INIT_BLOCK({
4444
/* relax.take */
4545
TVM_REGISTER_NODE_TYPE(TakeAttrs);
4646

47-
Expr take(Expr x, Expr indices, Optional<int64_t> axis) {
47+
Expr take(Expr x, Expr indices, Optional<int64_t> axis, String mode) {
4848
ObjectPtr<TakeAttrs> attrs = make_object<TakeAttrs>();
4949
attrs->axis = std::move(axis);
50+
attrs->mode = std::move(mode);
5051

5152
static const Op& op = Op::Get("relax.take");
5253
return Call(op, {std::move(x), std::move(indices)}, Attrs(attrs), {});

0 commit comments

Comments
 (0)