Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Enhance tvm infra for op attrs #15854

Merged
merged 1 commit into from
Aug 13, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions contrib/tvmop/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,11 @@ def get_target(device):

# TODO: attach instruction features to the library, e.g., avx-512, etc.
for operator_def in __OP_DEF__:
for sch, args in operator_def.invoke_all():
for sch, args, name in operator_def.invoke_all():
if tvm.module.enabled(get_target(operator_def.target)):
func_list = func_list_llvm if operator_def.target == "cpu" else func_list_cuda
func_lower = tvm.lower(sch, args,
name=operator_def.get_op_name(args),
name=name,
binds=operator_def.get_binds(args))
func_list.append(func_lower)

Expand Down
12 changes: 8 additions & 4 deletions contrib/tvmop/opdef.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ def __init__(self, func, name, target, auto_broadcast, **kwargs):
# {"ldtype": "int32", "rdtype": "float16"},
# {"ldtype": "int32", "rdtype": "int16"},
# ]
self.attrs = kwargs.pop('attrs', [])
self.attrs_valid = kwargs.pop('attrs_valid', lambda **kwargs: True)
args = [k for k in kwargs]
values = [kwargs[k] if isinstance(kwargs[k], (list, tuple)) else [kwargs[k]]
for k in args]
Expand All @@ -72,10 +74,12 @@ def __call__(self, *args, **kwargs):

def invoke_all(self):
for each_kwargs in self.arg_combination:
yield self.func(**each_kwargs)

def get_op_name(self, args):
return self.name + ''.join(["%s_%d" % (arg.dtype, len(arg.shape)) for arg in args])
if (self.attrs_valid(**each_kwargs)):
sch, args = self.func(**each_kwargs)
name = self.name \
+ ''.join(["{}_{}".format(key, each_kwargs[key]) for key in self.attrs]) \
+ ''.join(["%s_%d" % (arg.dtype, len(arg.shape)) for arg in args])
yield sch, args, name

def get_binds(self, args):
if self.auto_broadcast:
Expand Down