Skip to content

Commit 9a99fc8

Browse files
authored
[Utils] Allow classmethod and staticmethod in TVMDerivedObject (#14249)
Instance methods that exist in the user-defined class but not in the TVM base are forward using `__getattr__`. However, this is only applied for attribute look of instances, and doesn't apply for attribute lookup on the class object itself, such as when calling a classmethod or staticmethod. This commit exposes class methods and static methods in the wrapper class, if they are defined in the user-defined subclass.
1 parent 06fabe4 commit 9a99fc8

File tree

2 files changed

+24
-0
lines changed

2 files changed

+24
-0
lines changed

python/tvm/meta_schedule/utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,9 @@ def __setattr__(self, name, value):
128128
TVMDerivedObject.__name__ = cls.__name__
129129
TVMDerivedObject.__doc__ = cls.__doc__
130130
TVMDerivedObject.__module__ = cls.__module__
131+
for key, value in cls.__dict__.items():
132+
if isinstance(value, (classmethod, staticmethod)):
133+
setattr(TVMDerivedObject, key, value)
131134
return TVMDerivedObject
132135

133136

tests/python/unittest/test_meta_schedule_post_order_apply.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -404,5 +404,26 @@ def _get_sch(filter_fn):
404404
assert len(schs) == 8
405405

406406

407+
def test_meta_schedule_derived_object():
408+
@derived_object
409+
class RemoveBlock(PyScheduleRule):
410+
@classmethod
411+
def class_construct(cls):
412+
return cls()
413+
414+
@staticmethod
415+
def static_construct():
416+
return RemoveBlock()
417+
418+
inst_by_init = RemoveBlock()
419+
assert isinstance(inst_by_init, RemoveBlock)
420+
421+
inst_by_classmethod = RemoveBlock.class_construct()
422+
assert isinstance(inst_by_classmethod, RemoveBlock)
423+
424+
inst_by_staticmethod = RemoveBlock.static_construct()
425+
assert isinstance(inst_by_staticmethod, RemoveBlock)
426+
427+
407428
if __name__ == "__main__":
408429
tvm.testing.main()

0 commit comments

Comments
 (0)