Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 13 additions & 15 deletions src/relay/ir/expr_functor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ Expr ExprMutator::VisitExpr_(const VarNode* var_node) {
if (var_node->type_annotation.defined()) {
type_annotation = this->VisitType(var_node->type_annotation);
}
return WithFields(GetRef<Var>(var_node), std::move(var_node->vid), std::move(type_annotation));
return WithFields(GetRef<Var>(var_node), var_node->vid, type_annotation);
}

Expr ExprMutator::VisitExpr_(const ConstantNode* op) { return GetRef<Expr>(op); }
Expand All @@ -183,7 +183,7 @@ Expr ExprMutator::VisitExpr_(const TupleNode* tuple_node) {
auto new_field = this->Mutate(field);
fields.push_back(new_field);
}
return WithFields(GetRef<Tuple>(tuple_node), std::move(fields));
return WithFields(GetRef<Tuple>(tuple_node), fields);
}

Expr ExprMutator::VisitExpr_(const FunctionNode* func_node) {
Expand All @@ -203,8 +203,7 @@ Expr ExprMutator::VisitExpr_(const FunctionNode* func_node) {
auto ret_type = this->VisitType(func_node->ret_type);
auto body = this->Mutate(func_node->body);

return WithFields(GetRef<Function>(func_node), std::move(params), std::move(body),
std::move(ret_type), std::move(ty_params));
return WithFields(GetRef<Function>(func_node), params, body, ret_type, ty_params);
}

Expr ExprMutator::VisitExpr_(const CallNode* call_node) {
Expand All @@ -225,45 +224,44 @@ Expr ExprMutator::VisitExpr_(const CallNode* call_node) {
call_args.push_back(new_arg);
}

return WithFields(GetRef<Call>(call_node), std::move(new_op), std::move(call_args), {},
std::move(ty_args));
return WithFields(GetRef<Call>(call_node), new_op, call_args, {}, ty_args);
}

Expr ExprMutator::VisitExpr_(const LetNode* let_node) {
Var var = Downcast<Var>(this->Mutate(let_node->var));
auto value = this->Mutate(let_node->value);
auto body = this->Mutate(let_node->body);

return WithFields(GetRef<Let>(let_node), std::move(var), std::move(value), std::move(body));
return WithFields(GetRef<Let>(let_node), var, value, body);
}

Expr ExprMutator::VisitExpr_(const IfNode* if_node) {
auto cond = this->Mutate(if_node->cond);
auto true_b = this->Mutate(if_node->true_branch);
auto false_b = this->Mutate(if_node->false_branch);

return WithFields(GetRef<If>(if_node), std::move(cond), std::move(true_b), std::move(false_b));
return WithFields(GetRef<If>(if_node), cond, true_b, false_b);
}

Expr ExprMutator::VisitExpr_(const TupleGetItemNode* get_item) {
Expr tuple = this->Mutate(get_item->tuple);
return WithFields(GetRef<TupleGetItem>(get_item), std::move(tuple));
return WithFields(GetRef<TupleGetItem>(get_item), tuple);
}

Expr ExprMutator::VisitExpr_(const RefCreateNode* ref_create) {
Expr value = this->Mutate(ref_create->value);
return WithFields(GetRef<RefCreate>(ref_create), std::move(value));
return WithFields(GetRef<RefCreate>(ref_create), value);
}

Expr ExprMutator::VisitExpr_(const RefReadNode* ref_read) {
Expr ref = this->Mutate(ref_read->ref);
return WithFields(GetRef<RefRead>(ref_read), std::move(ref));
return WithFields(GetRef<RefRead>(ref_read), ref);
}

Expr ExprMutator::VisitExpr_(const RefWriteNode* ref_write) {
Expr ref = this->Mutate(ref_write->ref);
Expr value = this->Mutate(ref_write->value);
return WithFields(GetRef<RefWrite>(ref_write), std::move(ref), std::move(value));
return WithFields(GetRef<RefWrite>(ref_write), ref, value);
}

Expr ExprMutator::VisitExpr_(const ConstructorNode* c) { return GetRef<Expr>(c); }
Expand All @@ -275,13 +273,13 @@ Expr ExprMutator::VisitExpr_(const MatchNode* match_node) {
}
Expr data = Mutate(match_node->data);

return WithFields(GetRef<Match>(match_node), std::move(data), std::move(clauses));
return WithFields(GetRef<Match>(match_node), data, clauses);
}

Clause ExprMutator::VisitClause(const Clause& clause) {
Pattern lhs = VisitPattern(clause->lhs);
Expr rhs = Mutate(clause->rhs);
return WithFields(std::move(clause), std::move(lhs), std::move(rhs));
return WithFields(clause, lhs, rhs);
}

Pattern ExprMutator::VisitPattern(const Pattern& p) { return p; }
Expand Down Expand Up @@ -462,7 +460,7 @@ class ExprBinder : public MixedModeMutator, PatternMutator {

Clause VisitClause(const Clause& clause) final {
Pattern lhs = VisitPattern(clause->lhs);
return WithFields(std::move(clause), std::move(lhs), VisitExpr(clause->rhs));
return WithFields(clause, lhs, VisitExpr(clause->rhs));
}

Var VisitVar(const Var& v) final {
Expand Down
4 changes: 2 additions & 2 deletions src/relay/transforms/annotate_target.cc
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ class AnnotateTargetRewriter : public ExprRewriter {
auto tuple = Downcast<Tuple>(post);

auto target_n_args = AnnotateArgs(tuple->fields);
auto new_expr = WithFields(std::move(tuple), std::move(std::get<1>(target_n_args)));
auto new_expr = WithFields(tuple, std::get<1>(target_n_args));
op_expr_to_target_[new_expr] = std::get<0>(target_n_args);
return std::move(new_expr);
}
Expand Down Expand Up @@ -378,7 +378,7 @@ class CallOpsTargetRewriter : public AnnotateTargetRewriter {
for (auto f : tuple->fields) {
new_fields.push_back(InsertCompilerEndAndPropogateTarget(f));
}
return WithFields(std::move(tuple), std::move(new_fields));
return WithFields(tuple, new_fields);
}

Expr Rewrite_(const TupleGetItemNode* op, const Expr& post) override {
Expand Down
30 changes: 14 additions & 16 deletions src/relay/transforms/device_planner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@ class RewriteOnDevices : public ExprMutator {
Expr tuple = VisitExpr(tuple_get_item_node->tuple);
OnDeviceProps props = GetOnDeviceProps(tuple);

Expr tuple_get_item = WithFields(GetRef<TupleGetItem>(tuple_get_item_node), std::move(tuple));
Expr tuple_get_item = WithFields(GetRef<TupleGetItem>(tuple_get_item_node), tuple);
if (props.body.defined() && props.is_normal()) {
VLOG(2) << "wrapping tuple get item:" << std::endl
<< PrettyPrint(GetRef<TupleGetItem>(tuple_get_item_node)) << std::endl
Expand Down Expand Up @@ -363,8 +363,8 @@ class RewriteOnDevices : public ExprMutator {
}
expr = VisitExpr(expr);
for (auto itr = bindings.rbegin(); itr != bindings.rend(); ++itr) {
expr = WithFields(/*let=*/std::move(std::get<0>(*itr)), /*opt_var=*/{},
/*opt_value=*/std::move(std::get<1>(*itr)), /*opt_body=*/std::move(expr));
expr = WithFields(/*let=*/std::get<0>(*itr), /*opt_var=*/{},
/*opt_value=*/std::get<1>(*itr), /*opt_body=*/expr);
}
return expr;
}
Expand All @@ -378,7 +378,7 @@ class RewriteOnDevices : public ExprMutator {
<< "to be fixed to VirtualDevice " << props.virtual_device;
body = MaybeOnDeviceFixed(props.body, props.virtual_device);
}
return WithFields(GetRef<Function>(function_node), function_node->params, std::move(body));
return WithFields(GetRef<Function>(function_node), function_node->params, body);
}

Expr VisitExpr_(const CallNode* call_node) final {
Expand Down Expand Up @@ -990,7 +990,7 @@ class DeviceCapturer : public ExprMutator {
for (const auto& field : tuple_node->fields) {
fields.push_back(VisitChild(tuple, field));
}
return WithFields(std::move(tuple), std::move(fields));
return WithFields(tuple, fields);
}

Expr VisitExpr_(const FunctionNode* function_node) final {
Expand Down Expand Up @@ -1025,8 +1025,7 @@ class DeviceCapturer : public ExprMutator {
/*expected_virtual_device=*/result_virtual_device,
/*child_virtual_device=*/GetVirtualDevice(function_node->body), function_node->body);

Function func = WithFields(GetRef<Function>(function_node), std::move(function_node->params),
std::move(body));
Function func = WithFields(GetRef<Function>(function_node), function_node->params, body);
return FunctionOnDevice(func, std::move(param_virtual_devices),
std::move(result_virtual_device));
}
Expand Down Expand Up @@ -1102,9 +1101,9 @@ class DeviceCapturer : public ExprMutator {
if (call_node->op == CallLoweredOp()) {
Call new_call =
CallLowered(Downcast<GlobalVar>(op), args, /*call_lowered_attrs=*/{}, /*span=*/{});
return WithFields(call, std::move(new_call->op), std::move(new_call->args));
return WithFields(call, new_call->op, new_call->args);
} else {
return WithFields(call, std::move(op), std::move(args));
return WithFields(call, op, args);
}
}

Expand Down Expand Up @@ -1145,33 +1144,32 @@ class DeviceCapturer : public ExprMutator {
Expr cond = VisitChild(ife, if_node->cond);
Expr true_branch = VisitChild(ife, if_node->true_branch);
Expr false_branch = VisitChild(ife, if_node->false_branch);
return WithFields(std::move(ife), std::move(cond), std::move(true_branch),
std::move(false_branch));
return WithFields(ife, cond, true_branch, false_branch);
}

Expr VisitExpr_(const TupleGetItemNode* tuple_get_item_node) final {
auto tuple_get_item = GetRef<TupleGetItem>(tuple_get_item_node);
Expr tuple = VisitChild(tuple_get_item, tuple_get_item_node->tuple);
return WithFields(std::move(tuple_get_item), std::move(tuple));
return WithFields(tuple_get_item, tuple);
}

Expr VisitExpr_(const RefCreateNode* ref_create_node) final {
auto ref_create = GetRef<RefCreate>(ref_create_node);
Expr value = VisitChild(ref_create, ref_create_node->value);
return WithFields(std::move(ref_create), std::move(value));
return WithFields(ref_create, value);
}

Expr VisitExpr_(const RefReadNode* ref_read_node) final {
auto ref_read = GetRef<RefRead>(ref_read_node);
Expr ref = VisitChild(ref_read, ref_read_node->ref);
return WithFields(std::move(ref_read), std::move(ref));
return WithFields(ref_read, ref);
}

Expr VisitExpr_(const RefWriteNode* ref_write_node) final {
auto ref_write = GetRef<RefWrite>(ref_write_node);
Expr ref = VisitChild(ref_write, ref_write_node->ref);
Expr value = VisitChild(ref_write, ref_write_node->value);
return WithFields(std::move(ref_write), std::move(ref), std::move(value));
return WithFields(ref_write, ref, value);
}

Expr VisitExpr_(const MatchNode* match_node) final {
Expand All @@ -1184,7 +1182,7 @@ class DeviceCapturer : public ExprMutator {
Expr rhs = VisitChild(match, clause->rhs);
clauses.push_back(Clause(lhs, rhs));
}
return WithFields(std::move(match), std::move(data), std::move(clauses));
return WithFields(match, data, clauses);
}

VirtualDevice GetVirtualDevice(const Expr& expr) {
Expand Down
2 changes: 1 addition & 1 deletion src/relay/transforms/first_order_gradient.cc
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ struct FirstOrderReverseAD : ExprFunctor<ADValue(const Expr&)> {
field_bindings.push_back(f_ad->get<ADTensor>().forward);
}
// reconstruct tuple using let-bound variables to avoid duplication
auto orig = WithFields(GetRef<Tuple>(tuple_node), std::move(field_bindings));
auto orig = WithFields(GetRef<Tuple>(tuple_node), field_bindings);
orig->checked_type_ = tt;
auto ret = std::make_shared<ADTensor>(ll, orig, diag_ctx);
// for orig = tuple(x1, ..., xn), tuple_grad(x1, ..., xn, G) = [pi(G, 1), ..., pi(G, n)]
Expand Down
2 changes: 1 addition & 1 deletion src/relay/transforms/forward_rewrite.cc
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ class ForwardRewriter : private MixedModeMutator {
fields.push_back(this->GetTempExpr(tuple_node->fields[i], post_tuple_node->fields[i]));
}

return WithFields(GetRef<Tuple>(tuple_node), std::move(fields));
return WithFields(GetRef<Tuple>(tuple_node), fields);
}

Expr Rewrite_(const CallNode* call_node, const Expr& post) final {
Expand Down
2 changes: 1 addition & 1 deletion src/relay/transforms/fuse_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -905,7 +905,7 @@ class FuseMutator : private MixedModeMutator {
}
// This tuple is an intermediate node in the group
Array<Expr> new_fields = GetNewArguments(tuple_node->fields, ret_group);
return WithFields(GetRef<Tuple>(tuple_node), std::move(new_fields));
return WithFields(GetRef<Tuple>(tuple_node), new_fields);
}

Expr Rewrite_(const TupleGetItemNode* tuple_get, const Expr& post) {
Expand Down
2 changes: 1 addition & 1 deletion src/relay/transforms/memory_alloc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ class DialectRewriter : public transform::DeviceAwareExprMutator {
}
new_fields.push_back(new_field);
}
return WithFields(GetRef<Tuple>(tuple_node), std::move(new_fields));
return WithFields(GetRef<Tuple>(tuple_node), new_fields);
}

void PreVisitLetBlock_(const LetNode* let_node) final { scopes_.emplace_back(); }
Expand Down
2 changes: 1 addition & 1 deletion src/relay/transforms/partition_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -470,7 +470,7 @@ IRModule FlattenTupleOutputs(IRModule module) {

// Return a tuple of compiler_ends in the place of the tuple that was
// annotated with a compiler_end.
return WithFields(GetRef<Tuple>(tuple_node), std::move(new_fields));
return WithFields(GetRef<Tuple>(tuple_node), new_fields);
}
}
return post;
Expand Down
4 changes: 2 additions & 2 deletions src/relay/transforms/split_args.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,12 @@ class ArgumentSplitter : public ExprRewriter {
for (int j = 0; j < argsCount; ++j) {
args.push_back(tuple_node->fields[j + startIdx]);
}
Tuple new_tuple = WithFields(GetRef<Tuple>(tuple_node), std::move(args));
Tuple new_tuple = WithFields(GetRef<Tuple>(tuple_node), args);
Expr body = MakeConcatenate(new_tuple, param->axis);
splitted[i] = StopFusion(body);
}
tvm::Array<Expr> tuple_args(splitted);
Tuple new_tuple = WithFields(GetRef<Tuple>(tuple_node), std::move(tuple_args));
Tuple new_tuple = WithFields(GetRef<Tuple>(tuple_node), tuple_args);
return MakeConcatenate(new_tuple, param->axis);
}
return post;
Expand Down
2 changes: 1 addition & 1 deletion src/relay/transforms/to_a_normal_form.cc
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ class Fill : ExprFunctor<Expr(const Expr&, const Var&)>, private transform::Lexi
for (const auto& a : tuple_node->fields) {
fields.push_back(VisitExpr(a));
}
return Compound(e, WithFields(GetRef<Tuple>(tuple_node), std::move(fields)), v);
return Compound(e, WithFields(GetRef<Tuple>(tuple_node), fields), v);
}

Expr VisitExpr_(const TupleGetItemNode* t, const Var& v) final {
Expand Down
2 changes: 1 addition & 1 deletion src/relay/transforms/to_cps.cc
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ Function ToCPS(const Function& f, const IRModule& m, CPSMap* cm, VarMap* vm,
std::function<Expr()> next;
next = [&]() {
return (fields.size() == tuple_node->fields.size())
? k(WithFields(GetRef<Tuple>(tuple_node), std::move(fields)))
? k(WithFields(GetRef<Tuple>(tuple_node), fields))
: VisitExpr(tuple_node->fields[fields.size()], [&](const Expr& v) {
fields.push_back(v);
return next();
Expand Down
4 changes: 2 additions & 2 deletions src/relay/transforms/transform_layout.h
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ Expr LayoutRewriter(const Call& ref_call, const Array<Expr>& new_args, const Obj
Expr tmp = push_back_one_arg(x);
fields.push_back(tmp);
}
normal_new_args.push_back(WithFields(tuple_new_arg, std::move(fields)));
normal_new_args.push_back(WithFields(tuple_new_arg, fields));
} else {
Expr tmp = push_back_one_arg(new_arg);
normal_new_args.push_back(tmp);
Expand Down Expand Up @@ -383,7 +383,7 @@ Expr LayoutRewriter(const Call& ref_call, const Array<Expr>& new_args, const Obj
transformed_tuple_arg.push_back(memorizer.Transform(arg_item, new_in[pt], new_in2[pt]));
pt++;
}
transformed_args.push_back(WithFields(tuple_arg, std::move(transformed_tuple_arg)));
transformed_args.push_back(WithFields(tuple_arg, transformed_tuple_arg));
} else {
transformed_args.push_back(memorizer.Transform(arg, new_in[pt], new_in2[pt]));
pt++;
Expand Down