Skip to content

Commit

Permalink
[Relay tests] AlterOpLayout - Temporary attr update
Browse files Browse the repository at this point in the history
  • Loading branch information
anijain2305 committed Nov 17, 2019
1 parent 5d66e7a commit 6d727b6
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 32 deletions.
6 changes: 6 additions & 0 deletions include/tvm/relay/op.h
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,12 @@ class OpRegistry {
inline OpRegistry& set_attr(const std::string& attr_name, // NOLINT(*)
const ValueType& value, int plevel = 10);

/*!
* \brief Resets an attr of the registry.
* \param attr_name The name of the attribute.
*/
inline void reset_attr(const std::string& attr_name);

// set the name of the op to be the same as registry
inline OpRegistry& set_name() { // NOLINT(*)
if (get()->name.length() == 0) {
Expand Down
16 changes: 16 additions & 0 deletions python/tvm/relay/op/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,22 @@ def set_attr(self, attr_name, value, plevel=10):
"""
_OpSetAttr(self, attr_name, value, plevel)

def reset_attr(self, attr_name):
"""Reset attribute about the operator.
Parameters
----------
attr_name : str
The attribute name
value : object
The attribute value
plevel : int
The priority level
"""
_OpResetAttr(self, attr_name)


def get(op_name):
"""Get the Op for a given name
Expand Down
20 changes: 20 additions & 0 deletions src/relay/ir/op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,17 @@ const bool Op::HasGenericAttr(const std::string& key) {
return true;
}

// Resets attr of the OpMap.
void OpRegistry::reset_attr(const std::string& key) {
OpManager* mgr = OpManager::Global();
std::lock_guard<std::mutex> lock(mgr->mutex);
std::unique_ptr<GenericOpMap>& op_map = mgr->attr[key];
if (op_map == nullptr) {
return;
}
op_map->data_.clear();
}

void OpRegistry::UpdateAttr(const std::string& key,
TVMRetValue value,
int plevel) {
Expand Down Expand Up @@ -152,6 +163,15 @@ TVM_REGISTER_API("relay.op._OpSetAttr")
reg.set_attr(attr_name, value, plevel);
});

TVM_REGISTER_API("relay.op._OpResetAttr")
.set_body([](TVMArgs args, TVMRetValue* rv) {
Op op = args[0];
std::string attr_name = args[1];
auto& reg =
OpRegistry::Registry()->__REGISTER_OR_GET__(op->name).set_name();
reg.reset_attr(attr_name);
});

TVM_REGISTER_API("relay.op._Register")
.set_body([](TVMArgs args, TVMRetValue* rv) {
std::string op_name = args[0];
Expand Down
Loading

0 comments on commit 6d727b6

Please sign in to comment.