Skip to content

Commit 8ca8e38

Browse files
[Relay] WithFields method for Call, Function, Var, TupleGetItem, If, Let, RefCreate, RefRead, RefWrite, Match, and Clause (#9569)
* Implement WithFields for Relay exprs * lint
1 parent 0a4cc89 commit 8ca8e38

File tree

9 files changed

+571
-133
lines changed

9 files changed

+571
-133
lines changed

include/tvm/relay/adt.h

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,8 +260,25 @@ class Clause : public ObjectRef {
260260
TVM_DLL explicit Clause(Pattern lhs, Expr rhs);
261261

262262
TVM_DEFINE_OBJECT_REF_METHODS(Clause, ObjectRef, ClauseNode);
263+
TVM_DEFINE_OBJECT_REF_COW_METHOD(ClauseNode);
263264
};
264265

266+
/*!
267+
* \brief Returns the clause with given properties. A null property denotes 'no change'.
268+
* Returns clause if all properties are unchanged. Otherwise, returns a copy with the new fields.
269+
* \param clause The clause to copy.
270+
* \param opt_lhs The (optional) lhs for the copied clause. If none, ret_clause->lhs = clause->lhs.
271+
* \param opt_rhs The (optional) rhs for the copied clause. If none,
272+
* ret_clause->rhs = clause->rhs.
273+
* \return If all
274+
* properties are null or the same as the property in the input clause (i.e., opt_lhs is null or
275+
* opt_lhs.value() == clause->lhs, etc.), then we return clause. Otherwise, we return a copy of
276+
* clause with the different fields overwritten. (i.e., if opt_lhs.value() != clause->lhs, then
277+
* ret_clause->lhs = opt_lhs.value()).
278+
*/
279+
Clause WithFields(Clause clause, Optional<Pattern> opt_lhs = Optional<Pattern>(),
280+
Optional<Expr> opt_rhs = Optional<Expr>());
281+
265282
/*! \brief ADT pattern matching exression. */
266283
class Match;
267284
/*! \brief Match container node. */
@@ -315,8 +332,30 @@ class Match : public Expr {
315332
TVM_DLL Match(Expr data, tvm::Array<Clause> clauses, bool complete = true, Span span = Span());
316333

317334
TVM_DEFINE_OBJECT_REF_METHODS(Match, RelayExpr, MatchNode);
335+
TVM_DEFINE_OBJECT_REF_COW_METHOD(MatchNode);
318336
};
319337

338+
/*!
339+
* \brief Returns the match with given properties. A null property denotes 'no change'.
340+
* Returns match if all properties are unchanged. Otherwise, returns a copy with the new fields.
341+
* \param match The match to copy.
342+
* \param opt_data The (optional) data for the copied match. If none, ret_match->data = match->data.
343+
* \param opt_clauses The (optional) clauses for the copied match. If none, ret_match->clauses =
344+
* match->clauses.
345+
* \param opt_complete The (optional) complete for the copied match. If none, ret_match->complete =
346+
* match->complete.
347+
* \param opt_span The (optional) span for the copied match. If none, ret_match->span = match->span.
348+
* \return If all properties are null or the same as the
349+
* property in the input match (i.e., opt_clauses is null or opt_clauses.value() == match->clauses,
350+
* etc.), then we return match. Otherwise, we return a copy of match with the different fields
351+
* overwritten. (i.e., if opt_clauses.value() != match->clauses, then ret_match->clauses =
352+
* opt_clauses.value()).
353+
*/
354+
Match WithFields(Match match, Optional<Expr> opt_data = Optional<Expr>(),
355+
Optional<Array<Clause>> opt_clauses = Optional<Array<Clause>>(),
356+
Optional<Bool> opt_complete = Optional<Bool>(),
357+
Optional<Span> opt_span = Optional<Span>());
358+
320359
} // namespace relay
321360
} // namespace tvm
322361

include/tvm/relay/expr.h

Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,8 +230,26 @@ class Var : public Expr {
230230
TVM_DLL Var(Id vid, Type type_annotation, Span span = Span());
231231

232232
TVM_DEFINE_OBJECT_REF_METHODS(Var, RelayExpr, VarNode);
233+
TVM_DEFINE_OBJECT_REF_COW_METHOD(VarNode);
233234
};
234235

236+
/*!
237+
* \brief Returns the var with given properties. A null property denotes 'no change'.
238+
* Returns var if all properties are unchanged. Otherwise, returns a copy with the new fields.
239+
* \param var The var to copy.
240+
* \param opt_vid The (optional) vid for the copied var. If none, ret_var->vid = var->vid.
241+
* \param opt_type_annotation The (optional) type_annotation for the copied var. If none,
242+
* ret_var->type_annotation = var->type_annotation.
243+
* \param opt_span The (optional) span for the copied var. If none, ret_var->span = var->span.
244+
* \return If all properties are null or the same as the property in the input var
245+
* (i.e., opt_vid is null or opt_vid.value() == var->vid, etc.), then we return var. Otherwise,
246+
* we return a copy of call with the different fields overwritten. (i.e., if
247+
* opt_vid.value() != var->vid, then ret_var->vid = opt_.value()).
248+
*/
249+
Var WithFields(Var var, Optional<Id> opt_vid = Optional<Id>(),
250+
Optional<Type> opt_type_annotation = Optional<Type>(),
251+
Optional<Span> opt_span = Optional<Span>());
252+
235253
/*!
236254
* \brief Call corresponds to operator invocation.
237255
* Corresponds to the operator in computational graph terminology.
@@ -331,8 +349,31 @@ class Call : public Expr {
331349
Array<Type> type_args = Array<Type>(), Span span = Span());
332350

333351
TVM_DEFINE_OBJECT_REF_METHODS(Call, RelayExpr, CallNode);
352+
TVM_DEFINE_OBJECT_REF_COW_METHOD(CallNode);
334353
};
335354

355+
/*!
356+
* \brief Returns the call with given properties. A null property denotes 'no change'.
357+
* Returns call if all properties are unchanged. Otherwise, returns a copy with the new fields.
358+
* \param call The call to copy.
359+
* \param opt_op The (optional) op for the copied call. If none, ret_call->op = call->op.
360+
* \param opt_args The (optional) args for the copied call. If none, ret_call->args = call->args.
361+
* \param opt_attrs The (optional) attrs for the copied call. If none, ret_call->attrs =
362+
* call->attrs.
363+
* \param opt_type_args The (optional) type args for the copied call. If none,
364+
* ret_call->type_args = call->type_args.
365+
* \param opt_span The (optional) span for the copied call. If none, ret_call->span = call->span.
366+
* \return If all properties are null or the same as the property in the input call
367+
* (i.e., opt_op is null or opt_op.value() == call->op, etc.), then we return call. Otherwise, we
368+
* return a copy of call with the different fields overwritten. (i.e., if opt_op.value() !=
369+
* call->op, then ret_call->op = opt_op.value()).
370+
*/
371+
Call WithFields(Call call, Optional<Expr> opt_op = Optional<Expr>(),
372+
Optional<Array<Expr>> opt_args = Optional<Array<Expr>>(),
373+
Optional<Attrs> opt_attrs = Optional<Attrs>(),
374+
Optional<Array<Type>> opt_type_args = Optional<Array<Type>>(),
375+
Optional<Span> opt_span = Optional<Span>());
376+
336377
/*!
337378
* \brief Let binding that binds a local var and optionally a type annotation.
338379
*
@@ -405,8 +446,27 @@ class Let : public Expr {
405446
TVM_DLL Let(Var var, Expr value, Expr body, Span span = Span());
406447

407448
TVM_DEFINE_OBJECT_REF_METHODS(Let, RelayExpr, LetNode);
449+
TVM_DEFINE_OBJECT_REF_COW_METHOD(LetNode);
408450
};
409451

452+
/*!
453+
* \brief Returns the let with given properties. A null property denotes 'no change'.
454+
* Returns let if all properties are unchanged. Otherwise, returns a copy with the new fields.
455+
* \param let The let to copy.
456+
* \param opt_var The (optional) var for the copied let. If none, ret_let->op = let->op.
457+
* \param opt_value The (optional) value for the copied let. If none, ret_let->args = let->args.
458+
* \param opt_body The (optional) body for the copied let. If none, ret_let->attrs = let->attrs.
459+
* \param opt_span The (optional) span for the copied let. If none, ret_let->span = let->span.
460+
* \return If all properties are null or the same as the property in the input let (i.e., opt_var is
461+
* null or opt_var.value() == let->var, etc.), then we return let. Otherwise, we return a copy of
462+
* let with the different fields overwritten. (i.e., if opt_var.value() != let->var, then
463+
* ret_let->var = opt_var.value()).
464+
*/
465+
Let WithFields(Let let, Optional<Var> opt_var = Optional<Var>(),
466+
Optional<Expr> opt_value = Optional<Expr>(),
467+
Optional<Expr> opt_body = Optional<Expr>(),
468+
Optional<Span> opt_span = Optional<Span>());
469+
410470
/*!
411471
* \brief Condition expression
412472
*
@@ -466,8 +526,32 @@ class If : public Expr {
466526
TVM_DLL If(Expr cond, Expr true_branch, Expr false_branch, Span span = Span());
467527

468528
TVM_DEFINE_OBJECT_REF_METHODS(If, RelayExpr, IfNode);
529+
TVM_DEFINE_OBJECT_REF_COW_METHOD(IfNode);
469530
};
470531

532+
/*!
533+
* \brief Returns the if_expr with given properties. A null property denotes 'no change'.
534+
* Returns if_expr if all properties are unchanged. Otherwise, returns a copy with the new fields.
535+
* \param if_expr The if expression to copy.
536+
* \param opt_cond The (optional) cond for the copied if_expr. If none, ret_if->cond =
537+
* if_expr->cond.
538+
* \param opt_true_branch The (optional) true_branch for the copied if_expr. If none,
539+
* ret_if->true_branch = ret_if->false_branch.
540+
* \param opt_false_branch The (optional) false_branch
541+
* for the copied if_expr. If none, ret_if->false_branch = if_expr->false_branch.
542+
* \param opt_span
543+
* The (optional) span for the copied if_expr. If none, ret_if->span = if_expr->span.
544+
* \return If all
545+
* properties are null or the same as the property in the input if_expr (i.e., opt_cond is null or
546+
* opt_cond.value() == if_expr->cond, etc.), then we return if_expr. Otherwise, we return a copy of
547+
* if_expr with the different fields overwritten. (i.e., if opt_cond.value() != if_expr->cond, then
548+
* ret_if->cond = opt_cond.value()).
549+
*/
550+
If WithFields(If if_expr, Optional<Expr> opt_cond = Optional<Expr>(),
551+
Optional<Expr> opt_true_branch = Optional<Expr>(),
552+
Optional<Expr> opt_false_branch = Optional<Expr>(),
553+
Optional<Span> opt_span = Optional<Span>());
554+
471555
/*! \brief Get index-th field out of a tuple. */
472556
class TupleGetItem;
473557
class TupleGetItemNode : public ExprNode {
@@ -508,8 +592,30 @@ class TupleGetItem : public Expr {
508592
TVM_DLL TupleGetItem(Expr tuple, int index, Span span = Span());
509593

510594
TVM_DEFINE_OBJECT_REF_METHODS(TupleGetItem, RelayExpr, TupleGetItemNode);
595+
TVM_DEFINE_OBJECT_REF_COW_METHOD(TupleGetItemNode);
511596
};
512597

598+
/*!
599+
* \brief Returns the tuple_get_item with given properties. A null property denotes 'no change'.
600+
* Returns if_expr if all properties are unchanged. Otherwise, returns a copy with the new fields.
601+
* \param tuple_get_item The tuple_get_item to copy.
602+
* \param opt_tuple The (optional) tuple for the copied tuple_get_item. If none,
603+
* ret_tuple_get_item->tuple = tuple_get_item->tuple.
604+
* \param opt_index The (optional) index for the copied tuple_get_item. If none,
605+
* ret_tuple_get_item->index = tuple_get_item->index.
606+
* \param
607+
* opt_span The (optional) span for the copied tuple_get_item. If none,
608+
* ret_tuple_get_item->span = tuple_get_item->span.
609+
* \return If all properties are null or the same as the property in the input tuple_get_item
610+
* (i.e., opt_tuple is null or opt_tuple.value() == tuple_get_item->tuple, etc.), then we return
611+
* tuple_get_item. Otherwise, we return a copy of tuple_get_item with the different fields
612+
* overwritten. (i.e., if opt_tuple.value() != tuple_get_item->tuple, then
613+
* ret_tuple_get_item->tuple = opt_tuple.value()).
614+
*/
615+
TupleGetItem WithFields(TupleGetItem tuple_get_item, Optional<Expr> opt_tuple = Optional<Expr>(),
616+
Optional<Integer> opt_index = Optional<Integer>(),
617+
Optional<Span> opt_span = Optional<Span>());
618+
513619
/*! \brief Create a new Reference out of initial value. */
514620
class RefCreate;
515621
class RefCreateNode : public ExprNode {
@@ -547,8 +653,27 @@ class RefCreate : public Expr {
547653
TVM_DLL explicit RefCreate(Expr value, Span span = Span());
548654

549655
TVM_DEFINE_OBJECT_REF_METHODS(RefCreate, RelayExpr, RefCreateNode);
656+
TVM_DEFINE_OBJECT_REF_COW_METHOD(RefCreateNode);
550657
};
551658

659+
/*!
660+
* \brief Returns the ref create with given properties. A null property denotes 'no change'.
661+
* Returns ref_create if all properties are unchanged. Otherwise, returns a copy with the new
662+
* fields.
663+
* \param ref_create The ref_create to copy.
664+
* \param opt_value The (optional) value for the copied ref_create. If none,
665+
* ret_ref_create->value = ref_create->value.
666+
* \param opt_span The (optional) span for the copied ref_create. If none,
667+
* ret_ref_create->span = ref_create->span.
668+
* \return If all properties are null or the same as the property in the input ref_create
669+
* (i.e., opt_value is null or opt_value.value() == ref_create->value, etc.), then we return
670+
* ref_create. Otherwise, we return a copy of ref_create with the different fields overwritten.
671+
* (i.e., if opt_value.value() != ref_create->value, then
672+
* ret_ref_create->value = opt_value.value()).
673+
*/
674+
RefCreate WithFields(RefCreate ref_create, Optional<Expr> opt_value = Optional<Expr>(),
675+
Optional<Span> opt_span = Optional<Span>());
676+
552677
/*! \brief Get value out of Reference. */
553678
class RefRead;
554679
class RefReadNode : public ExprNode {
@@ -586,7 +711,26 @@ class RefRead : public Expr {
586711
TVM_DLL explicit RefRead(Expr ref, Span span = Span());
587712

588713
TVM_DEFINE_OBJECT_REF_METHODS(RefRead, RelayExpr, RefReadNode);
714+
TVM_DEFINE_OBJECT_REF_COW_METHOD(RefReadNode);
589715
};
716+
717+
/*!
718+
* \brief Returns the ref read with given properties. A null property denotes 'no change'.
719+
* Returns ref_read if all properties are unchanged. Otherwise, returns a copy with the new fields.
720+
* \param ref_read The ref_read to copy.
721+
* \param opt_ref The (optional) ref for the copied ref_read. If none, ret_ref_read->ref =
722+
* ref_read->ref.
723+
* \param opt_span
724+
* The (optional) span for the copied ref_read. If none, ret_ref_read->span = ref_read->span.
725+
* \return If all properties are null or the same as the property in the input ref_read
726+
* (i.e., opt_ref is null or opt_ref.value() == ref_read->ref, etc.), then we return ref_read.
727+
* Otherwise, we return a copy of ref_read with the different fields overwritten.
728+
* (i.e., if opt_ref.value() != ref_read->ref, then
729+
* ret_ref_read->ref = opt_ref.value()).
730+
*/
731+
RefRead WithFields(RefRead ref_read, Optional<Expr> opt_ref = Optional<Expr>(),
732+
Optional<Span> opt_span = Optional<Span>());
733+
590734
/*! \brief Set value of Reference. The whole expression evaluates to an Empty Tuple. */
591735
class RefWrite;
592736
class RefWriteNode : public ExprNode {
@@ -629,8 +773,29 @@ class RefWrite : public Expr {
629773
TVM_DLL RefWrite(Expr ref, Expr value, Span span = Span());
630774

631775
TVM_DEFINE_OBJECT_REF_METHODS(RefWrite, RelayExpr, RefWriteNode);
776+
TVM_DEFINE_OBJECT_REF_COW_METHOD(RefWriteNode);
632777
};
633778

779+
/*!
780+
* \brief Returns the ref write with given properties. A null property denotes 'no change'.
781+
* Returns ref_write if all properties are unchanged. Otherwise, returns a copy with the new fields.
782+
* \param ref_write The ref_write to copy.
783+
* \param opt_ref The (optional) ref for the copied ref_write. If none,
784+
* ret_ref_write->ref = ref_write->ref.
785+
* \param opt_value The (optional) value for the copied ref_write. If none,
786+
* ret_ref_write->value = ref_write->value.
787+
* \param opt_span
788+
* The (optional) span for the copied ref_write. If none, ret_ref_write->span = ref_write->span.
789+
* \return If all properties are null or the same as the property in the input ref_write
790+
* (i.e., opt_ref is null or opt_ref.value() == ref_write->ref, etc.), then we return ref_write.
791+
* Otherwise, we return a copy of ref_write with the different fields overwritten.
792+
* (i.e., if ref_write.value() != ref_write->ref, then
793+
* ret_ref_write->ref = opt_ref.value()).
794+
*/
795+
RefWrite WithFields(RefWrite ref_write, Optional<Expr> opt_ref = Optional<Expr>(),
796+
Optional<Expr> opt_value = Optional<Expr>(),
797+
Optional<Span> opt_span = Optional<Span>());
798+
634799
/*!
635800
* \brief Base class of the temporary expression.
636801
*

include/tvm/relay/function.h

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,35 @@ class Function : public BaseFunc {
119119
TVM_DEFINE_OBJECT_REF_COW_METHOD(FunctionNode);
120120
};
121121

122+
/*!
123+
* \brief Returns the function with given properties. A null property denotes 'no change'.
124+
* Returns function if all properties are unchanged. Otherwise, returns a copy with the new fields.
125+
* \param function The function to copy.
126+
* \param opt_params The (optional) params for the copied function. If none,
127+
* ret_function->params = function->params.
128+
* \param opt_body The (optional) body for the copied function. If none,
129+
* ret_function->body = function->body.
130+
* \param opt_ret_type The (optional) return type for the copied function. If none,
131+
* ret_function->ret_type = function->ret_type.
132+
* \param opt_ty_params The (optional) type params for the copied function. If none,
133+
* ret_function->type_params = function->type_params.
134+
* \param opt_attrs
135+
* The (optional) attributes for the copied function. If none,
136+
* ret_function->attrs = function->attrs.
137+
* \param opt_span The (optional) span for the copied function. If none,
138+
* ret_function->span = function->span.
139+
* \return If all properties are null or the same as the property in the input function
140+
* (i.e., opt_params is null or opt_params.value() == function->params, etc.), then we return
141+
* function. Otherwise, we return a copy of function with the different fields overwritten. (i.e.,
142+
* if opt_params.value() != function->params, then ret_function->params = opt_params.value()).
143+
*/
144+
Function WithFields(Function function, Optional<Array<Var>> opt_params = Optional<Array<Var>>(),
145+
Optional<Expr> opt_body = Optional<Expr>(),
146+
Optional<Type> opt_ret_type = Optional<Type>(),
147+
Optional<Array<TypeVar>> opt_ty_params = Optional<Array<TypeVar>>(),
148+
Optional<DictAttrs> opt_attrs = Optional<DictAttrs>(),
149+
Optional<Span> opt_span = Optional<Span>());
150+
122151
/*!
123152
* \brief namespace of the attributes that can be attached to a relay::Function.
124153
*/

src/relay/ir/adt.cc

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,20 @@ Clause::Clause(Pattern lhs, Expr rhs) {
104104
data_ = std::move(n);
105105
}
106106

107+
Clause WithFields(Clause clause, Optional<Pattern> opt_lhs, Optional<Expr> opt_rhs) {
108+
Pattern lhs = opt_lhs.value_or(clause->lhs);
109+
Expr rhs = opt_rhs.value_or(clause->rhs);
110+
111+
bool unchanged = lhs.same_as(clause->lhs) && rhs.same_as(clause->rhs);
112+
113+
if (!unchanged) {
114+
ClauseNode* cow_clause_node = clause.CopyOnWrite();
115+
cow_clause_node->lhs = lhs;
116+
cow_clause_node->rhs = rhs;
117+
}
118+
return std::move(clause);
119+
}
120+
107121
TVM_REGISTER_NODE_TYPE(ClauseNode);
108122

109123
TVM_REGISTER_GLOBAL("relay.ir.Clause").set_body_typed([](Pattern lhs, Expr rhs) {
@@ -125,6 +139,38 @@ Match::Match(Expr data, tvm::Array<Clause> clauses, bool complete, Span span) {
125139
data_ = std::move(n);
126140
}
127141

142+
Match WithFields(Match match, Optional<Expr> opt_data, Optional<Array<Clause>> opt_clauses,
143+
Optional<Bool> opt_complete, Optional<Span> opt_span) {
144+
Expr data = opt_data.value_or(match->data);
145+
Array<Clause> clauses = opt_clauses.value_or(match->clauses);
146+
Bool complete = opt_complete.value_or(Bool(match->complete));
147+
Span span = opt_span.value_or(match->span);
148+
149+
bool unchanged =
150+
data.same_as(match->data) && (complete == match->complete) && span.same_as(match->span);
151+
152+
// Check that all clauses are unchanged
153+
if (unchanged) {
154+
bool all_clauses_unchanged = true;
155+
if (clauses.size() == match->clauses.size()) {
156+
for (size_t i = 0; i < clauses.size(); i++) {
157+
all_clauses_unchanged &= clauses[i].same_as(match->clauses[i]);
158+
}
159+
} else {
160+
all_clauses_unchanged = false;
161+
}
162+
unchanged &= all_clauses_unchanged;
163+
}
164+
if (!unchanged) {
165+
MatchNode* cow_match_node = match.CopyOnWrite();
166+
cow_match_node->data = data;
167+
cow_match_node->clauses = clauses;
168+
cow_match_node->complete = complete;
169+
cow_match_node->span = span;
170+
}
171+
return std::move(match);
172+
}
173+
128174
TVM_REGISTER_NODE_TYPE(MatchNode);
129175

130176
TVM_REGISTER_GLOBAL("relay.ir.Match")

0 commit comments

Comments
 (0)