Skip to content

Commit df7aaa6

Browse files
Michalis Papapdimitrioumbs-octoml
andcommitted
[Relay][AST] Add WithFields for Constant and GlobalVar nodes
Co-authored-by: Mark Shields <[email protected]>
1 parent af8569c commit df7aaa6

File tree

4 files changed

+74
-2
lines changed

4 files changed

+74
-2
lines changed

include/tvm/ir/expr.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -260,9 +260,10 @@ class GlobalVarNode : public RelayExprNode {
260260
*/
261261
class GlobalVar : public RelayExpr {
262262
public:
263-
TVM_DLL explicit GlobalVar(String name_hint, Type type = {});
263+
TVM_DLL explicit GlobalVar(String name_hint, Type type = {}, Span span = {});
264264

265265
TVM_DEFINE_OBJECT_REF_METHODS(GlobalVar, RelayExpr, GlobalVarNode);
266+
+ TVM_DEFINE_OBJECT_REF_COW_METHOD(GlobalVarNode);
266267
};
267268

268269
// PrimExprs that are useful as runtime containers.

include/tvm/relay/expr.h

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,22 @@
3939
#include "./type.h"
4040

4141
namespace tvm {
42+
43+
/*!
44+
* \brief Returns the global_var with given properties. A null property denotes 'no change'.
45+
* Returns this if all properties are unchanged. Otherwise, returns a copy with the new fields.
46+
* \param global_var The global var to copy
47+
* \param opt_name_hint The (optional) op name of the global var
48+
* \param opt_type The (optional) type for the global var
49+
* \param opt_virtual_device The (optional) virtual_device for the copied constant. If none,
50+
* ret_constant->virtual_device = constant->virtual_device.
51+
* \param opt_span The (optional) span for the copied global var. If none,
52+
* ret_constant->span = constant->span.
53+
*/
54+
GlobalVar WithFields(GlobalVar global_var, Optional<String> opt_name_hint = {},
55+
Optional<Type> opt_type = {}, Optional<VirtualDevice> opt_virtual_device = {},
56+
Optional<Span> opt_span = {});
57+
4258
namespace relay {
4359

4460
using Expr = tvm::RelayExpr;
@@ -97,8 +113,23 @@ class Constant : public Expr {
97113
TVM_DLL explicit Constant(runtime::NDArray data, Span span = Span());
98114

99115
TVM_DEFINE_OBJECT_REF_METHODS(Constant, RelayExpr, ConstantNode);
116+
TVM_DEFINE_OBJECT_REF_COW_METHOD(ConstantNode);
100117
};
101118

119+
/*!
120+
* \brief Returns the constant with given properties. A null property denotes 'no change'.
121+
* Returns this if all properties are unchanged. Otherwise, returns a copy with the new fields.
122+
* \param constant The constant to copy
123+
* \param op_data The (optional) data for the copied constant. If none, ret_constant->data =
124+
* constant->data.
125+
* \param opt_virtual_device The (optional) virtual_device for the copied constant. If none,
126+
* ret_constant->virtual_device = constant->virtual_device.
127+
* \param opt_span The (optional) span for the copied constant. If none,
128+
* ret_constant->span = constant->span.
129+
*/
130+
Constant WithFields(Constant constant, Optional<runtime::NDArray> opt_data = {},
131+
Optional<VirtualDevice> opt_virtual_device = {}, Optional<Span> opt_span = {});
132+
102133
/*! \brief Tuple of multiple Exprs */
103134
class Tuple;
104135
/*! \brief Tuple container */

src/ir/expr.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,10 +141,11 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
141141
p->stream << "range(min=" << op->min << ", ext=" << op->extent << ')';
142142
});
143143

144-
GlobalVar::GlobalVar(String name_hint, Type type) {
144+
GlobalVar::GlobalVar(String name_hint, Type type, Span span) {
145145
ObjectPtr<GlobalVarNode> n = make_object<GlobalVarNode>();
146146
n->name_hint = std::move(name_hint);
147147
n->checked_type_ = std::move(type);
148+
n->span = std::move(span);
148149
data_ = std::move(n);
149150
}
150151

src/relay/ir/expr.cc

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,26 @@
2727

2828
namespace tvm {
2929

30+
GlobalVar WithFields(GlobalVar global_var, Optional<String> opt_name_hint, Optional<Type> opt_type,
31+
Optional<VirtualDevice> opt_virtual_device, Optional<Span> opt_span) {
32+
String name_hint = opt_name_hint.value_or(global_var->name_hint);
33+
Type type = opt_type.value_or(global_var->checked_type());
34+
VirtualDevice virtual_device = opt_virtual_device.value_or(global_var->virtual_device());
35+
Span span = opt_span.value_or(global_var->span);
36+
bool all_fields_unchanged =
37+
name_hint.same_as(global_var->name_hint) && type.same_as(global_var->checked_type()) &&
38+
virtual_device.same_as(global_var->virtual_device()) && span.same_as(global_var->span);
39+
if (!all_fields_unchanged) {
40+
GlobalVarNode* cow_global_var_node = global_var.CopyOnWrite();
41+
cow_global_var_node->name_hint = name_hint;
42+
cow_global_var_node->checked_type_ = type;
43+
cow_global_var_node->virtual_device_ = virtual_device;
44+
cow_global_var_node->span = span;
45+
}
46+
47+
return global_var;
48+
}
49+
3050
VirtualDevice RelayExprNode::virtual_device() const {
3151
if (!this->virtual_device_.defined()) {
3252
// virtual_device_ should always be defined, unless we imported this node from JSON using an old
@@ -77,6 +97,25 @@ TensorType ConstantNode::tensor_type() const {
7797
return TensorType(shape, dtype);
7898
}
7999

100+
Constant WithFields(Constant constant, Optional<runtime::NDArray> opt_data,
101+
Optional<VirtualDevice> opt_virtual_device, Optional<Span> opt_span) {
102+
runtime::NDArray data = opt_data.value_or(constant->data);
103+
VirtualDevice virtual_device = opt_virtual_device.value_or(constant->virtual_device());
104+
Span span = opt_span.value_or(constant->span);
105+
106+
bool all_fields_unchanged = data.same_as(constant->data) &&
107+
virtual_device.same_as(constant->virtual_device()) &&
108+
span.same_as(constant->span);
109+
110+
if (!all_fields_unchanged) {
111+
ConstantNode* cow_constant_node = constant.CopyOnWrite();
112+
cow_constant_node->data = data;
113+
cow_constant_node->virtual_device_ = virtual_device;
114+
cow_constant_node->span = span;
115+
}
116+
return constant;
117+
}
118+
80119
Tuple::Tuple(tvm::Array<relay::Expr> fields, Span span) {
81120
ObjectPtr<TupleNode> n = make_object<TupleNode>();
82121
n->fields = std::move(fields);

0 commit comments

Comments
 (0)