Skip to content

Commit c3647bb

Browse files
committed
[FFI] More strict tuple constructor checking (apache#18023)
0;276;0c# This is the 1st commit message: [FFI] More strict tuple constructor checking This PR fixes a case where tuple constructor mismatches during vector forwarding. In this case UType&& was mistakenly matched and used. A testcase is added.
1 parent 60cba0e commit c3647bb

File tree

3 files changed

+45
-4
lines changed

3 files changed

+45
-4
lines changed

include/tvm/ffi/container/container_details.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,16 @@ class IterAdapter {
199199

200200
IterAdapter operator-(difference_type offset) const { return IterAdapter(iter_ - offset); }
201201

202+
IterAdapter& operator+=(difference_type offset) {
203+
iter_ += offset;
204+
return *this;
205+
}
206+
207+
IterAdapter& operator-=(difference_type offset) {
208+
iter_ -= offset;
209+
return *this;
210+
}
211+
202212
template <typename T = IterAdapter>
203213
typename std::enable_if<std::is_same<iterator_category, std::random_access_iterator_tag>::value,
204214
typename T::difference_type>::type inline

include/tvm/ffi/container/tuple.h

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,10 +55,13 @@ class Tuple : public ObjectRef {
5555
template <typename... UTypes,
5656
typename = std::enable_if_t<(details::type_contains_v<Types, UTypes> && ...), int>>
5757
Tuple(Tuple<UTypes...>&& other) : ObjectRef(std::move(other)) {}
58-
template <typename... UTypes>
59-
explicit Tuple(UTypes&&... args) : ObjectRef(MakeTupleNode(std::forward<UTypes>(args)...)) {
60-
static_assert(sizeof...(Types) == sizeof...(UTypes), "Tuple size mismatch");
61-
}
58+
59+
template <typename... UTypes,
60+
typename = std::enable_if_t<sizeof...(Types) == sizeof...(UTypes) &&
61+
!(sizeof...(Types) == 1 &&
62+
(std::is_same_v<std::remove_cv_t<UTypes>, Tuple<Types>> &&
63+
...))>>
64+
explicit Tuple(UTypes&&... args) : ObjectRef(MakeTupleNode(std::forward<UTypes>(args)...)) {}
6265

6366
TVM_FFI_INLINE Tuple& operator=(const Tuple<Types...>& other) {
6467
data_ = other.data_;

tests/cpp/test_tuple.cc

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,4 +136,32 @@ TEST(Tuple, Upcast) {
136136
static_assert(details::type_contains_v<Tuple<Any, float>, Tuple<int, float>>);
137137
static_assert(details::type_contains_v<Tuple<TNumber, float>, Tuple<TInt, float>>);
138138
}
139+
140+
TEST(Tuple, ArrayIterForwarding) {
141+
Tuple<TInt, TInt> t0(1, 2);
142+
Tuple<TInt, TInt> t1(3, 4);
143+
Array<Tuple<TInt, TInt>> arr0 = {t0, t1};
144+
std::vector<Tuple<TInt, TInt>> vec0 = {t0};
145+
vec0.insert(vec0.end(), arr0.begin(), arr0.end());
146+
EXPECT_EQ(vec0.size(), 3);
147+
EXPECT_EQ(vec0[0].get<0>()->value, 1);
148+
EXPECT_EQ(vec0[0].get<1>()->value, 2);
149+
EXPECT_EQ(vec0[1].get<0>()->value, 1);
150+
EXPECT_EQ(vec0[1].get<1>()->value, 2);
151+
EXPECT_EQ(vec0[2].get<0>()->value, 3);
152+
EXPECT_EQ(vec0[2].get<1>()->value, 4);
153+
}
154+
155+
TEST(Tuple, ArrayIterForwardSingleElem) {
156+
Tuple<TInt> t0(1);
157+
Tuple<TInt> t1(2);
158+
Array<Tuple<TInt>> arr0 = {t0, t1};
159+
std::vector<Tuple<TInt>> vec0 = {t0};
160+
vec0.insert(vec0.end(), arr0.begin(), arr0.end());
161+
EXPECT_EQ(vec0.size(), 3);
162+
EXPECT_EQ(vec0[0].get<0>()->value, 1);
163+
EXPECT_EQ(vec0[1].get<0>()->value, 1);
164+
EXPECT_EQ(vec0[2].get<0>()->value, 2);
165+
}
166+
139167
} // namespace

0 commit comments

Comments
 (0)