Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
tvm infra for op attrs (#15854)
Browse files Browse the repository at this point in the history
  • Loading branch information
hzfan authored and yzhliu committed Aug 13, 2019
1 parent 05f3ae1 commit 67daae7
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 6 deletions.
4 changes: 2 additions & 2 deletions contrib/tvmop/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,11 @@ def get_target(device):

# TODO: attach instruction features to the library, e.g., avx-512, etc.
for operator_def in __OP_DEF__:
for sch, args in operator_def.invoke_all():
for sch, args, name in operator_def.invoke_all():
if tvm.module.enabled(get_target(operator_def.target)):
func_list = func_list_llvm if operator_def.target == "cpu" else func_list_cuda
func_lower = tvm.lower(sch, args,
name=operator_def.get_op_name(args),
name=name,
binds=operator_def.get_binds(args))
func_list.append(func_lower)

Expand Down
12 changes: 8 additions & 4 deletions contrib/tvmop/opdef.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ def __init__(self, func, name, target, auto_broadcast, **kwargs):
# {"ldtype": "int32", "rdtype": "float16"},
# {"ldtype": "int32", "rdtype": "int16"},
# ]
self.attrs = kwargs.pop('attrs', [])
self.attrs_valid = kwargs.pop('attrs_valid', lambda **kwargs: True)
args = [k for k in kwargs]
values = [kwargs[k] if isinstance(kwargs[k], (list, tuple)) else [kwargs[k]]
for k in args]
Expand All @@ -72,10 +74,12 @@ def __call__(self, *args, **kwargs):

def invoke_all(self):
for each_kwargs in self.arg_combination:
yield self.func(**each_kwargs)

def get_op_name(self, args):
return self.name + ''.join(["%s_%d" % (arg.dtype, len(arg.shape)) for arg in args])
if (self.attrs_valid(**each_kwargs)):
sch, args = self.func(**each_kwargs)
name = self.name \
+ ''.join(["{}_{}".format(key, each_kwargs[key]) for key in self.attrs]) \
+ ''.join(["%s_%d" % (arg.dtype, len(arg.shape)) for arg in args])
yield sch, args, name

def get_binds(self, args):
if self.auto_broadcast:
Expand Down

0 comments on commit 67daae7

Please sign in to comment.