Skip to content

Commit 8d27973

Browse files
committed
[Relax][Transform] Add mode choice, NaN mode, and warning for take()
- Add a `mode` parameter to Relax’s `take()` - Add `NaN` mode to `take()` - Add unit tests covering all `take()` modes - Add a warning log for `fast` mode - Unify default modes in lower layers to `fast` for consistency with Relax
1 parent fa46d7a commit 8d27973

File tree

10 files changed

+147
-64
lines changed

10 files changed

+147
-64
lines changed

.lesshst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
.less-history-file:

include/tvm/relax/attrs/index.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,11 @@ namespace relax {
3232
/*! \brief Attributes used in take operator */
3333
struct TakeAttrs : public tvm::AttrsNode<TakeAttrs> {
3434
Optional<int64_t> axis;
35+
String mode;
3536

3637
TVM_DECLARE_ATTRS(TakeAttrs, "relax.attrs.TakeAttrs") {
3738
TVM_ATTR_FIELD(axis).describe("The axis over which to select values.");
39+
TVM_ATTR_FIELD(mode).describe("The mode for handling out-of-bounds indices.");
3840
}
3941
}; // struct TakeAttrs
4042

include/tvm/topi/transform.h

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1032,6 +1032,16 @@ inline Tensor take(const Tensor& a, const Tensor& indices, int batch_dims,
10321032
out_shape,
10331033
[&](const Array<Var>& out_index) { return a(UnravelIndex(indices(out_index), a_shape)); },
10341034
name, tag);
1035+
} else if (mode == "nan") {
1036+
return compute(
1037+
out_shape,
1038+
[&](const Array<Var>& out_index) {
1039+
auto idx = tvm::if_then_else(
1040+
indices(out_index) < 0 || indices(out_index) >= a_size,
1041+
tvm::FloatImm(a->dtype, std::numeric_limits<float>::quiet_NaN()), indices(out_index));
1042+
return a(UnravelIndex(idx, a_shape));
1043+
},
1044+
name, tag);
10351045
} else { // mode == "wrap"
10361046
return compute(
10371047
out_shape,
@@ -1094,7 +1104,7 @@ inline Tensor sequence_mask(const Tensor& data, const Tensor& valid_length, doub
10941104
* \return A Tensor whose op member is the take operation
10951105
*/
10961106
inline Tensor take(const Tensor& a, Variant<Tensor, PrimExpr> indices, int batch_dims, int axis,
1097-
std::string mode = "clip", std::string name = "T_take",
1107+
std::string mode = "fast", std::string name = "T_take",
10981108
std::string tag = kInjective) {
10991109
if (axis < 0) {
11001110
axis += static_cast<int>(a->shape.size());
@@ -1206,6 +1216,8 @@ inline Tensor take(const Tensor& a, Variant<Tensor, PrimExpr> indices, int batch
12061216
name, tag);
12071217
}
12081218
} else if (mode == "fast") {
1219+
LOG(WARNING) << "Fast mode segfaults when there are out-of-bounds indices. "
1220+
"Make sure input indices are in bound";
12091221
return compute(
12101222
out_shape,
12111223
[&](const Array<Var>& out_index) {
@@ -1224,6 +1236,29 @@ inline Tensor take(const Tensor& a, Variant<Tensor, PrimExpr> indices, int batch
12241236
return a(real_indices);
12251237
},
12261238
name, tag);
1239+
} else if (mode == "nan") {
1240+
return compute(
1241+
out_shape,
1242+
[&](const Array<Var>& out_index) {
1243+
Array<PrimExpr> indices_position;
1244+
for (size_t j = axis; j < static_cast<size_t>(axis + indices_len); ++j) {
1245+
indices_position.push_back(out_index[j]);
1246+
}
1247+
Array<PrimExpr> real_indices;
1248+
for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
1249+
real_indices.push_back(out_index[j]);
1250+
}
1251+
PrimExpr idx = get_index(indices_position);
1252+
real_indices.push_back(idx);
1253+
for (size_t j = axis + indices_len; j < out_index.size(); ++j) {
1254+
real_indices.push_back(out_index[j]);
1255+
}
1256+
PrimExpr in_bounds = idx >= 0 && idx < axis_dim;
1257+
return tvm::if_then_else(
1258+
in_bounds, a(real_indices),
1259+
tvm::tir::make_const(a->dtype, std::numeric_limits<float>::quiet_NaN()));
1260+
},
1261+
name, tag);
12271262
} else { // mode == "wrap"
12281263
return compute(
12291264
out_shape,

python/tvm/relax/op/index.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
PrimExprLike = Union[int, PrimExpr]
2727

2828

29-
def take(x: Expr, indices: Expr, axis: Optional[int] = None) -> Expr:
29+
def take(x: Expr, indices: Expr, axis: Optional[int] = None, mode: str = "fast") -> Expr:
3030
"""Take elements from a tensor along an axis.
3131
Its semantic is mostly similar to `numpy.take`
3232
(https://numpy.org/doc/stable/reference/generated/numpy.take.html),
@@ -45,12 +45,20 @@ def take(x: Expr, indices: Expr, axis: Optional[int] = None) -> Expr:
4545
The axis over which to select values.
4646
If it is none, the input tensor is required to be one-dimensional.
4747
48+
mode : str
49+
Specifies how out-of-bounds indices will behave.
50+
- fast (default): extra indices lead to seg fault (user must make sure indices are in-bound)
51+
- nan: produce NaNs for out-of-bounds indices
52+
- wrap: wrap around the indices
53+
- clip: clip to the range
54+
'clip' mode means that all indices that are too large are replaced
55+
by the index that addresses the last element along that axis.
4856
Returns
4957
-------
5058
ret : relax.Expr
5159
The taken result.
5260
"""
53-
return _ffi_api.take(x, indices, axis) # type: ignore
61+
return _ffi_api.take(x, indices, axis, mode) # type: ignore
5462

5563

5664
@args_converter.auto

python/tvm/relax/transform/legalize_ops/index.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,9 @@
2626

2727
@register_legalize("relax.take")
2828
def _take(bb: BlockBuilder, call: Call) -> Expr:
29-
# Currently Relax `take` operator doesn't provide the mode choices and
30-
# requires input indices to be in range.
31-
# We use fast mode, which leads to runtime error whenever some index is
32-
# out of bound.
33-
return bb.call_te(topi.take, call.args[0], call.args[1], call.attrs.axis, mode="fast")
29+
# Currently "fast" is the default mode, which leads to segmentation faults
30+
# when there are out-of-bounds indices.
31+
return bb.call_te(topi.take, call.args[0], call.args[1], call.attrs.axis, mode=call.attrs.mode)
3432

3533

3634
@register_legalize("relax.strided_slice")

python/tvm/topi/transform.py

Lines changed: 9 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,6 @@
1818
"""Injective transformation operators"""
1919
from __future__ import absolute_import as _abs
2020

21-
from math import pi
22-
import numpy as np
23-
2421
import tvm
2522
from tvm import te, topi
2623

@@ -99,8 +96,7 @@ def _compute(*idxs):
9996
axis_index = 0
10097
for i in range(0, len(idxs)):
10198
if i not in real_axis:
102-
dim = tvm.tir.if_then_else(a.shape[len(indices)] != 1, idxs[i], 0)
103-
indices.append(dim)
99+
indices.append(idxs[i])
104100
axis_index += 1
105101
return a(*indices)
106102

@@ -446,7 +442,7 @@ def split(ary, indices_or_sections, axis=0):
446442
return cpp.split(ary, indices_or_sections, axis)
447443

448444

449-
def take(a, indices, axis=None, batch_dims=0, mode="clip"):
445+
def take(a, indices, axis=None, batch_dims=0, mode="fast"):
450446
"""Take elements from an array along an axis.
451447
452448
Parameters
@@ -465,10 +461,13 @@ def take(a, indices, axis=None, batch_dims=0, mode="clip"):
465461
The number of batch dimensions. By default is 0.
466462
467463
mode : str, optional
468-
Specifies how out-of-bound indices will behave.
469-
clip - clip to the range (default)
470-
wrap - wrap around the indices
471-
fast - no clip or wrap around (user must make sure indices are in-bound)
464+
Specifies how out-of-bounds indices will behave.
465+
- fast (default): extra indices lead to seg fault (user must make sure indices are in-bound)
466+
- nan: produce NaNs for out-of-bounds indices
467+
- wrap: wrap around the indices
468+
- clip: clip to the range
469+
'clip' mode means that all indices that are too large are replaced
470+
by the index that addresses the last element along that axis.
472471
473472
Returns
474473
-------
@@ -1109,45 +1108,3 @@ def index_tensor(data, indices):
11091108
z = topi.index_tensor(x, [row, col]) # shape (2, 3)
11101109
"""
11111110
return topi.adv_index(data, indices)
1112-
1113-
1114-
def hamming_window(window_size, periodic, alpha, beta, dtype):
1115-
"""Hamming window function.
1116-
1117-
Parameters
1118-
----------
1119-
window_size: tvm.Expr
1120-
The size of returned window.
1121-
1122-
periodic: tvm.Expr
1123-
If True, returns a window to be used as periodic function.
1124-
If False, return a symmetric window.
1125-
1126-
alpha: tvm.Expr
1127-
The co-efficient alpha.
1128-
1129-
beta: tvm.Expr
1130-
The co-efficient beta.
1131-
1132-
Returns
1133-
-------
1134-
ret : tvm.te.Tensor
1135-
The result tensor.
1136-
"""
1137-
if window_size == 1:
1138-
return topi.const_vector(np.array([1], dtype=dtype))
1139-
1140-
periodic = topi.cast(periodic, "bool")
1141-
1142-
if periodic:
1143-
window_size += 1
1144-
1145-
index = topi.arange(0, window_size, dtype=dtype)
1146-
angular_freq = 2 * pi * index / (window_size - 1)
1147-
cos_values = topi.cos(angular_freq)
1148-
window = topi.cast(alpha - beta * cos_values, dtype=dtype)
1149-
1150-
if periodic:
1151-
return topi.strided_slice(window, [0], [window.shape[0] - 1])
1152-
1153-
return window

src/relax/op/tensor/index.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,10 @@ namespace relax {
3939
/* relax.take */
4040
TVM_REGISTER_NODE_TYPE(TakeAttrs);
4141

42-
Expr take(Expr x, Expr indices, Optional<int64_t> axis) {
42+
Expr take(Expr x, Expr indices, Optional<int64_t> axis, String mode) {
4343
ObjectPtr<TakeAttrs> attrs = make_object<TakeAttrs>();
4444
attrs->axis = std::move(axis);
45+
attrs->mode = std::move(mode);
4546

4647
static const Op& op = Op::Get("relax.take");
4748
return Call(op, {std::move(x), std::move(indices)}, Attrs(attrs), {});

src/relax/op/tensor/index.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,10 @@ namespace relax {
3838
* It is required to be a one-dimensional tensor which has integer dtype.
3939
* \param axis The axis over which to select values.
4040
* If it is `std::nullopt`, the input tensor is required to be one-dimensional.
41+
* \param mode The mode for handling out-of-bounds indices.
4142
* \return The taken result.
4243
*/
43-
Expr take(Expr x, Expr indices, Optional<int64_t> axis);
44+
Expr take(Expr x, Expr indices, Optional<int64_t> axis, String mode);
4445

4546
/*!
4647
* \brief Strided slice of a tensor.

src/relax/transform/reorder_take_after_matmul.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ std::tuple<DFPattern, ffi::TypedFunction<Expr(Expr, Map<DFPattern, Expr>)>> Crea
9595
// out_table.shape = [*batch, table_size]
9696
auto out_table = matmul(lhs, weights, DataType::Void());
9797
// new_output.shape = [*batch, outfeatures]
98-
auto new_output = take(out_table, indices, matmul_sinfo->ndim - 1);
98+
auto new_output = take(out_table, indices, matmul_sinfo->ndim - 1, attrs->mode);
9999

100100
return new_output;
101101
} else if (lhs_sinfo->ndim == 3 && weights_sinfo->ndim == 3 && indices_sinfo->ndim == 1 &&
@@ -132,7 +132,7 @@ std::tuple<DFPattern, ffi::TypedFunction<Expr(Expr, Map<DFPattern, Expr>)>> Crea
132132
// operations.
133133

134134
// duplicated_output.shape = [batch1, batch2, batch1, outfeatures]
135-
auto duplicated_output = take(indexed_output, indices, 2);
135+
auto duplicated_output = take(indexed_output, indices, 2, attrs->mode);
136136
// new_output.shape = [batch1, batch2, outfeatures]
137137
auto new_output = einsum(Tuple({duplicated_output}), "ijik->ijk");
138138

tests/python/relax/test_op_take.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,5 +154,85 @@ def main(A: R.Tensor(["n", "n"], "float16")):
154154
tvm.testing.assert_allclose(tvm_output.numpy(), np_expected)
155155

156156

157+
@tvm.testing.parametrize_targets("llvm")
158+
def test_take_nan_mode_OOB_indices(target, dev, axis):
159+
"""Test R.take with mode="nan" and out-of-bounds indices.
160+
This test checks that out-of-bounds indices produce NaN values in the output tensor.
161+
"""
162+
163+
@I.ir_module
164+
class Module:
165+
@R.function
166+
def main(A: R.Tensor([3, 3], "float16")):
167+
output = R.take(A, R.const([0, 1, 2, 3]), axis=axis, mode="nan")
168+
return output
169+
170+
built = tvm.compile(Module, target=target)
171+
vm = tvm.relax.VirtualMachine(built, dev)
172+
173+
np_input = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], dtype="float16")
174+
tvm_input = tvm.nd.array(np_input, dev)
175+
tvm_output = vm["main"](tvm_input)
176+
if axis == 0:
177+
np_expected = np.array(
178+
[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0], [np.nan, np.nan, np.nan]],
179+
dtype="float16",
180+
)
181+
elif axis == 1:
182+
np_expected = np.array(
183+
[[1.0, 2.0, 3.0, np.nan], [4.0, 5.0, 6.0, np.nan], [7.0, 8.0, 9.0, np.nan]],
184+
dtype="float16",
185+
)
186+
187+
tvm.testing.assert_allclose(tvm_output.numpy(), np_expected)
188+
189+
190+
@tvm.testing.parametrize_targets("llvm")
191+
def test_take_wrap_mode_OOB_indices(target, dev, axis):
192+
"""Test R.take with mode="wrap" and out-of-bounds indices.
193+
This test checks that out-of-bounds indices wrap around to the valid range.
194+
"""
195+
196+
@I.ir_module
197+
class Module:
198+
@R.function
199+
def main(A: R.Tensor([3, 3], "float16")):
200+
output = R.take(A, R.const([0, 1, 2, 3]), axis=axis, mode="wrap")
201+
return output
202+
203+
built = tvm.compile(Module, target=target)
204+
vm = tvm.relax.VirtualMachine(built, dev)
205+
206+
np_input = np.random.random(size=[3, 3]).astype("float16")
207+
tvm_input = tvm.nd.array(np_input, dev)
208+
tvm_output = vm["main"](tvm_input)
209+
np_expected = np.take(np_input, [0, 1, 2, 3], axis=axis, mode="wrap")
210+
211+
tvm.testing.assert_allclose(tvm_output.numpy(), np_expected)
212+
213+
214+
@tvm.testing.parametrize_targets("llvm")
215+
def test_take_clip_mode_OOB_indices(target, dev, axis):
216+
"""Test R.take with mode="clip" and out-of-bounds indices.
217+
This test checks that out-of-bounds indices are clipped to the valid range.
218+
"""
219+
220+
@I.ir_module
221+
class Module:
222+
@R.function
223+
def main(A: R.Tensor([3, 3], "float16")):
224+
output = R.take(A, R.const([0, 1, 2, 3]), axis=axis, mode="clip")
225+
return output
226+
227+
built = tvm.compile(Module, target=target)
228+
vm = tvm.relax.VirtualMachine(built, dev)
229+
np_input = np.random.random(size=[3, 3]).astype("float16")
230+
tvm_input = tvm.nd.array(np_input, dev)
231+
tvm_output = vm["main"](tvm_input)
232+
np_expected = np.take(np_input, [0, 1, 2, 3], axis=axis, mode="clip")
233+
234+
tvm.testing.assert_allclose(tvm_output.numpy(), np_expected)
235+
236+
157237
if __name__ == "__main__":
158238
tvm.testing.main()

0 commit comments

Comments
 (0)