From 3703d37a62b05f2d7581508cc4ff90a512fed80c Mon Sep 17 00:00:00 2001 From: Fan Date: Mon, 12 Aug 2019 11:10:08 +0800 Subject: [PATCH] tvm infra for op attrs --- contrib/tvmop/compile.py | 4 ++-- contrib/tvmop/opdef.py | 12 ++++++++---- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/contrib/tvmop/compile.py b/contrib/tvmop/compile.py index 94274fe4142a..e6af0a276560 100644 --- a/contrib/tvmop/compile.py +++ b/contrib/tvmop/compile.py @@ -44,11 +44,11 @@ def get_target(device): # TODO: attach instruction features to the library, e.g., avx-512, etc. for operator_def in __OP_DEF__: - for sch, args in operator_def.invoke_all(): + for sch, args, name in operator_def.invoke_all(): if tvm.module.enabled(get_target(operator_def.target)): func_list = func_list_llvm if operator_def.target == "cpu" else func_list_cuda func_lower = tvm.lower(sch, args, - name=operator_def.get_op_name(args), + name=name, binds=operator_def.get_binds(args)) func_list.append(func_lower) diff --git a/contrib/tvmop/opdef.py b/contrib/tvmop/opdef.py index c65824588047..32d1832d13dd 100644 --- a/contrib/tvmop/opdef.py +++ b/contrib/tvmop/opdef.py @@ -56,6 +56,8 @@ def __init__(self, func, name, target, auto_broadcast, **kwargs): # {"ldtype": "int32", "rdtype": "float16"}, # {"ldtype": "int32", "rdtype": "int16"}, # ] + self.attrs = kwargs.pop('attrs', []) + self.attrs_valid = kwargs.pop('attrs_valid', lambda **kwargs: True) args = [k for k in kwargs] values = [kwargs[k] if isinstance(kwargs[k], (list, tuple)) else [kwargs[k]] for k in args] @@ -72,10 +74,12 @@ def __call__(self, *args, **kwargs): def invoke_all(self): for each_kwargs in self.arg_combination: - yield self.func(**each_kwargs) - - def get_op_name(self, args): - return self.name + ''.join(["%s_%d" % (arg.dtype, len(arg.shape)) for arg in args]) + if (self.attrs_valid(**each_kwargs)): + sch, args = self.func(**each_kwargs) + name = self.name \ + + ''.join(["{}_{}".format(key, each_kwargs[key]) for key in self.attrs]) \ + + ''.join(["%s_%d" % (arg.dtype, len(arg.shape)) for arg in args]) + yield sch, args, name def get_binds(self, args): if self.auto_broadcast: