@@ -710,16 +710,17 @@ inline te::Tensor dynamic_strided_slice(const te::Tensor& x, const te::Tensor& b
710710 const te::Tensor& end, const te::Tensor& strides,
711711 std::string name = " T_strided_slice_dynamic" ,
712712 std::string tag = topi::kInjective ) {
713+ DataType index_dtype = begin->shape [0 ]->dtype ;
713714 const int64_t num_dynamic_axes = begin->shape [0 ].as <IntImmNode>()->value ;
714715 ICHECK_EQ (end->shape [0 ].as <IntImmNode>()->value , num_dynamic_axes);
715716 ICHECK_EQ (strides->shape [0 ].as <IntImmNode>()->value , num_dynamic_axes);
716717
717718 Array<PrimExpr> begin_expr, end_expr, strides_expr;
718719 for (int64_t i = 0 ; i < num_dynamic_axes; ++i) {
719- auto i64_ind = IntImm ( DataType::Int ( 64 ) , i);
720- begin_expr.push_back (begin (i64_ind ));
721- end_expr.push_back (end (i64_ind ));
722- strides_expr.push_back (strides (i64_ind ));
720+ auto ind = make_const (index_dtype , i);
721+ begin_expr.push_back (begin (ind ));
722+ end_expr.push_back (end (ind ));
723+ strides_expr.push_back (strides (ind ));
723724 }
724725 return dynamic_strided_slice (x, begin_expr, end_expr, strides_expr, name, tag);
725726}
@@ -822,9 +823,10 @@ inline Tensor strided_slice(const Tensor& x, const Array<Integer>& begin, const
822823 Array<Integer> end_full (end);
823824 Array<Integer> strides_full (strides);
824825
825- const IntImm one = IntImm (DataType::Int (64 ), 1 );
826- const IntImm zero = IntImm (DataType::Int (64 ), 0 );
827- const IntImm max_range = IntImm (DataType::Int (64 ), std::numeric_limits<int64_t >::max ());
826+ DataType index_dtype = begin.size () > 0 ? begin[0 ]->dtype : DataType::Int (64 );
827+ const IntImm one = IntImm (index_dtype, 1 );
828+ const IntImm zero = IntImm (index_dtype, 0 );
829+ const IntImm max_range = Downcast<IntImm>(max_value (index_dtype));
828830
829831 for (size_t i = strides.size (); i < src_tensor_dim; ++i) {
830832 strides_full.push_back (one);
0 commit comments