Skip to content

Commit 3a11abc

Browse files
wrongtest-intellifyongwww
authored andcommitted
[TOPI] Fix index dtype in topi strided_slice (apache#14022)
try use more proper index dtype than i64 to avoid integer arith issue in lowering.
1 parent caf7aa1 commit 3a11abc

File tree

1 file changed

+9
-7
lines changed

1 file changed

+9
-7
lines changed

include/tvm/topi/transform.h

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -710,16 +710,17 @@ inline te::Tensor dynamic_strided_slice(const te::Tensor& x, const te::Tensor& b
710710
const te::Tensor& end, const te::Tensor& strides,
711711
std::string name = "T_strided_slice_dynamic",
712712
std::string tag = topi::kInjective) {
713+
DataType index_dtype = begin->shape[0]->dtype;
713714
const int64_t num_dynamic_axes = begin->shape[0].as<IntImmNode>()->value;
714715
ICHECK_EQ(end->shape[0].as<IntImmNode>()->value, num_dynamic_axes);
715716
ICHECK_EQ(strides->shape[0].as<IntImmNode>()->value, num_dynamic_axes);
716717

717718
Array<PrimExpr> begin_expr, end_expr, strides_expr;
718719
for (int64_t i = 0; i < num_dynamic_axes; ++i) {
719-
auto i64_ind = IntImm(DataType::Int(64), i);
720-
begin_expr.push_back(begin(i64_ind));
721-
end_expr.push_back(end(i64_ind));
722-
strides_expr.push_back(strides(i64_ind));
720+
auto ind = make_const(index_dtype, i);
721+
begin_expr.push_back(begin(ind));
722+
end_expr.push_back(end(ind));
723+
strides_expr.push_back(strides(ind));
723724
}
724725
return dynamic_strided_slice(x, begin_expr, end_expr, strides_expr, name, tag);
725726
}
@@ -822,9 +823,10 @@ inline Tensor strided_slice(const Tensor& x, const Array<Integer>& begin, const
822823
Array<Integer> end_full(end);
823824
Array<Integer> strides_full(strides);
824825

825-
const IntImm one = IntImm(DataType::Int(64), 1);
826-
const IntImm zero = IntImm(DataType::Int(64), 0);
827-
const IntImm max_range = IntImm(DataType::Int(64), std::numeric_limits<int64_t>::max());
826+
DataType index_dtype = begin.size() > 0 ? begin[0]->dtype : DataType::Int(64);
827+
const IntImm one = IntImm(index_dtype, 1);
828+
const IntImm zero = IntImm(index_dtype, 0);
829+
const IntImm max_range = Downcast<IntImm>(max_value(index_dtype));
828830

829831
for (size_t i = strides.size(); i < src_tensor_dim; ++i) {
830832
strides_full.push_back(one);

0 commit comments

Comments
 (0)