Skip to content

Commit a1b207c

Browse files
electricliliesylc
authored andcommitted
WithFields for Tuples (apache#9533)
1 parent a768144 commit a1b207c

File tree

15 files changed

+108
-71
lines changed

15 files changed

+108
-71
lines changed

include/tvm/relay/expr.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,8 +142,20 @@ class Tuple : public Expr {
142142
TVM_DLL explicit Tuple(tvm::Array<relay::Expr> fields, Span span = Span());
143143

144144
TVM_DEFINE_OBJECT_REF_METHODS(Tuple, RelayExpr, TupleNode);
145+
TVM_DEFINE_OBJECT_REF_COW_METHOD(TupleNode);
145146
};
146147

148+
/*!
149+
* \brief Returns the tuple with given properties. A null property denotes 'no change'.
150+
* Returns this if all properties are unchanged. Otherwise, returns a copy with the new fields.
151+
* \param tuple The tuple to copy
152+
* \param opt_fields The (optional) fields for the copied tuple. If none, ret_tuple->fields =
153+
* tuple->fields.
154+
* \param opt_span The (optional) span for the copied tuple. If none, ret_tuple->span = tuple->span.
155+
*/
156+
Tuple WithFields(Tuple tuple, Optional<Array<Expr>> opt_fields = Optional<Array<Expr>>(),
157+
Optional<Span> opt_span = Optional<Span>(nullptr));
158+
147159
/*!
148160
* \brief Local variables used in the let expression.
149161
*

src/relay/ir/expr.cc

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,27 @@ TVM_REGISTER_NODE_TYPE(TupleNode);
7676
TVM_REGISTER_GLOBAL("relay.ir.Tuple").set_body_typed([](tvm::Array<relay::Expr> fields, Span span) {
7777
return Tuple(fields, span);
7878
});
79+
Tuple WithFields(Tuple tuple, Optional<Array<Expr>> opt_fields, Optional<Span> opt_span) {
80+
Array<Expr> fields = opt_fields.value_or(tuple->fields);
81+
Span span = opt_span.value_or(tuple->span);
82+
83+
bool all_fields_unchanged = true;
84+
if (fields.size() == tuple->fields.size()) {
85+
for (size_t i = 0; i < fields.size(); i++) {
86+
all_fields_unchanged &= fields[i].same_as(tuple->fields[i]);
87+
}
88+
} else {
89+
all_fields_unchanged = false;
90+
}
91+
92+
all_fields_unchanged = all_fields_unchanged && span.same_as(tuple->span);
93+
if (!all_fields_unchanged) {
94+
TupleNode* cow_tuple_node = tuple.CopyOnWrite();
95+
cow_tuple_node->fields = fields;
96+
cow_tuple_node->span = span;
97+
}
98+
return std::move(tuple);
99+
}
79100

80101
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
81102
.set_dispatch<TupleNode>([](const ObjectRef& ref, ReprPrinter* p) {

src/relay/ir/expr_functor.cc

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -177,20 +177,15 @@ Expr ExprMutator::VisitExpr_(const GlobalVarNode* op) { return GetRef<Expr>(op);
177177

178178
Expr ExprMutator::VisitExpr_(const OpNode* op) { return GetRef<Expr>(op); }
179179

180-
Expr ExprMutator::VisitExpr_(const TupleNode* op) {
180+
Expr ExprMutator::VisitExpr_(const TupleNode* tuple_node) {
181181
tvm::Array<Expr> fields;
182-
bool all_fields_unchanged = true;
183-
for (auto field : op->fields) {
182+
fields.reserve(tuple_node->fields.size());
183+
184+
for (auto field : tuple_node->fields) {
184185
auto new_field = this->Mutate(field);
185186
fields.push_back(new_field);
186-
all_fields_unchanged &= new_field.same_as(field);
187-
}
188-
189-
if (all_fields_unchanged) {
190-
return GetRef<Expr>(op);
191-
} else {
192-
return Tuple(fields, op->span);
193187
}
188+
return WithFields(GetRef<Tuple>(tuple_node), std::move(fields));
194189
}
195190

196191
Expr ExprMutator::VisitExpr_(const FunctionNode* op) {

src/relay/transforms/annotate_target.cc

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -266,11 +266,11 @@ class AnnotateTargetRewriter : public ExprRewriter {
266266

267267
virtual std::unique_ptr<Call> RewriteVarCall(const Call& post_call) { return nullptr; }
268268

269-
Expr Rewrite_(const TupleNode* op, const Expr& post) override {
270-
auto expr = Downcast<Tuple>(post);
269+
Expr Rewrite_(const TupleNode* tuple_node, const Expr& post) override {
270+
auto tuple = Downcast<Tuple>(post);
271271

272-
auto target_n_args = AnnotateArgs(expr->fields);
273-
auto new_expr = Tuple(std::get<1>(target_n_args));
272+
auto target_n_args = AnnotateArgs(tuple->fields);
273+
auto new_expr = WithFields(std::move(tuple), std::move(std::get<1>(target_n_args)));
274274
op_expr_to_target_[new_expr] = std::get<0>(target_n_args);
275275
return std::move(new_expr);
276276
}
@@ -370,13 +370,15 @@ class CallOpsTargetRewriter : public AnnotateTargetRewriter {
370370
return new_call;
371371
}
372372

373-
Expr Rewrite_(const TupleNode* op, const Expr& post) override {
374-
auto expr = Downcast<Tuple>(post);
373+
Expr Rewrite_(const TupleNode* tuple_node, const Expr& post) override {
374+
auto tuple = Downcast<Tuple>(post);
375375
Array<Expr> new_fields;
376-
for (auto f : expr->fields) {
376+
new_fields.reserve(tuple->fields.size());
377+
378+
for (auto f : tuple->fields) {
377379
new_fields.push_back(InsertCompilerEndAndPropogateTarget(f));
378380
}
379-
return std::move(Tuple(new_fields));
381+
return WithFields(std::move(tuple), std::move(new_fields));
380382
}
381383

382384
Expr Rewrite_(const TupleGetItemNode* op, const Expr& post) override {

src/relay/transforms/device_planner.cc

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -786,8 +786,7 @@ class DeviceCapturer : public ExprMutator {
786786
for (const auto& field : tuple_node->fields) {
787787
fields.push_back(VisitChild(tuple, field));
788788
}
789-
// TODO(mbs): Avoid copy
790-
return Tuple(std::move(fields), tuple_node->span);
789+
return WithFields(std::move(tuple), std::move(fields));
791790
}
792791

793792
Expr VisitExpr_(const FunctionNode* function_node) final {

src/relay/transforms/first_order_gradient.cc

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -195,11 +195,13 @@ struct FirstOrderReverseAD : ExprFunctor<ADValue(const Expr&)> {
195195
return ret;
196196
}
197197

198-
ADValue VisitExpr_(const TupleNode* op) final {
199-
auto tt = Downcast<TupleType>(op->checked_type());
198+
ADValue VisitExpr_(const TupleNode* tuple_node) final {
199+
auto tt = Downcast<TupleType>(tuple_node->checked_type());
200200
std::vector<ADValue> ad_fields;
201-
std::vector<Expr> field_bindings;
202-
for (const auto& f : op->fields) {
201+
Array<Expr> field_bindings;
202+
field_bindings.reserve(tuple_node->fields.size());
203+
204+
for (const auto& f : tuple_node->fields) {
203205
ADValue f_ad = VisitExpr(f);
204206
if (!dynamic_cast<ADTensor*>(f_ad.get())) {
205207
diag_ctx.EmitFatal(Diagnostic::Error(f->span)
@@ -209,7 +211,7 @@ struct FirstOrderReverseAD : ExprFunctor<ADValue(const Expr&)> {
209211
field_bindings.push_back(f_ad->get<ADTensor>().forward);
210212
}
211213
// reconstruct tuple using let-bound variables to avoid duplication
212-
auto orig = Tuple(field_bindings);
214+
auto orig = WithFields(GetRef<Tuple>(tuple_node), std::move(field_bindings));
213215
orig->checked_type_ = tt;
214216
auto ret = std::make_shared<ADTensor>(ll, orig, diag_ctx);
215217
// for orig = tuple(x1, ..., xn), tuple_grad(x1, ..., xn, G) = [pi(G, 1), ..., pi(G, n)]

src/relay/transforms/forward_rewrite.cc

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -113,21 +113,16 @@ class ForwardRewriter : private MixedModeMutator {
113113
}
114114
}
115115

116-
Expr Rewrite_(const TupleNode* op, const Expr& post) final {
116+
Expr Rewrite_(const TupleNode* tuple_node, const Expr& post) final {
117117
tvm::Array<Expr> fields;
118-
bool all_fields_unchanged = true;
119-
const auto* post_node = post.as<TupleNode>();
120-
for (size_t i = 0; i < op->fields.size(); ++i) {
121-
auto new_field = this->GetTempExpr(op->fields[i], post_node->fields[i]);
122-
fields.push_back(new_field);
123-
all_fields_unchanged &= new_field.same_as(op->fields[i]);
124-
}
118+
fields.reserve(tuple_node->fields.size());
125119

126-
if (all_fields_unchanged) {
127-
return GetRef<Expr>(op);
128-
} else {
129-
return Tuple(fields);
120+
const auto* post_tuple_node = post.as<TupleNode>();
121+
for (size_t i = 0; i < tuple_node->fields.size(); ++i) {
122+
fields.push_back(this->GetTempExpr(tuple_node->fields[i], post_tuple_node->fields[i]));
130123
}
124+
125+
return WithFields(GetRef<Tuple>(tuple_node), std::move(fields));
131126
}
132127

133128
Expr Rewrite_(const CallNode* call_node, const Expr& post) final {

src/relay/transforms/fuse_ops.cc

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -898,14 +898,14 @@ class FuseMutator : private MixedModeMutator {
898898
}
899899
}
900900

901-
Expr Rewrite_(const TupleNode* tuple, const Expr& post) {
902-
auto* ret_group = gmap_.at(tuple)->FindRoot();
903-
if (ret_group->root_ref == tuple) {
904-
return ExprMutator::VisitExpr_(tuple);
901+
Expr Rewrite_(const TupleNode* tuple_node, const Expr& post) {
902+
auto* ret_group = gmap_.at(tuple_node)->FindRoot();
903+
if (ret_group->root_ref == tuple_node) {
904+
return ExprMutator::VisitExpr_(tuple_node);
905905
}
906906
// This tuple is an intermediate node in the group
907-
Array<Expr> new_fields = GetNewArguments(tuple->fields, ret_group);
908-
return Tuple(new_fields);
907+
Array<Expr> new_fields = GetNewArguments(tuple_node->fields, ret_group);
908+
return WithFields(GetRef<Tuple>(tuple_node), std::move(new_fields));
909909
}
910910

911911
Expr Rewrite_(const TupleGetItemNode* tuple_get, const Expr& post) {

src/relay/transforms/memory_alloc.cc

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -84,18 +84,20 @@ class DialectRewriter : public transform::DeviceAwareExprMutator {
8484
Function Rewrite(const Function& expr) { return Downcast<Function>(Mutate(expr)); }
8585

8686
private:
87-
Expr VisitExpr_(const TupleNode* tn) final {
87+
Expr VisitExpr_(const TupleNode* tuple_node) final {
8888
LetList& scope = scopes_.back();
8989
Array<Expr> new_fields;
90-
for (auto field : tn->fields) {
90+
new_fields.reserve(tuple_node->fields.size());
91+
92+
for (auto field : tuple_node->fields) {
9193
auto new_field = Mutate(field);
9294
if (new_field->IsInstance<ConstantNode>()) {
9395
Var const_var("const", Type(nullptr));
9496
new_field = scope.Push(const_var, new_field);
9597
}
9698
new_fields.push_back(new_field);
9799
}
98-
return Tuple(new_fields);
100+
return WithFields(GetRef<Tuple>(tuple_node), std::move(new_fields));
99101
}
100102

101103
void PreVisitLetBlock_(const LetNode* let_node) final { scopes_.emplace_back(); }

src/relay/transforms/partial_eval.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -615,6 +615,8 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
615615
value.push_back(ps);
616616
expr.push_back(ps->dynamic);
617617
}
618+
// Note(@electriclilies): The partial evaluator seems to do some weird stuff with sharing.
619+
// Changing Tuple(expr) to WithFields(op, expr) causes some strange failures.
618620
return HasStatic(MkSTuple(value), ll->Push(Tuple(expr)));
619621
}
620622

0 commit comments

Comments
 (0)