Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
use inspect to detect dispatchable func
Browse files Browse the repository at this point in the history
  • Loading branch information
Fan committed Oct 25, 2019
1 parent b4f2d78 commit 2fe9f81
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 7 deletions.
2 changes: 1 addition & 1 deletion contrib/tvmop/core/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def compute_dot(A, B):
return C


@defop(name="dot", target="cpu", dispatch=True, dtype=AllTypes)
@defop(name="dot", target="cpu", dtype=AllTypes)
def dot(dtype, fallback):
cfg = autotvm.get_config()
cfg.define_knob("bn", [64] if fallback else [64, 32])
Expand Down
13 changes: 7 additions & 6 deletions contrib/tvmop/opdef.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

# coding: utf-8
import tvm
import inspect
from tvm import autotvm
from itertools import product

Expand Down Expand Up @@ -48,7 +49,7 @@ class OpDef:
without considering whether dimension size equals to one.
TVM maps buffer[i][j][k] -> buffer[i][0][k] if dimension i's shape equals 1.
"""
def __init__(self, func, name, target, auto_broadcast, dispatch, **kwargs):
def __init__(self, func, name, target, auto_broadcast, **kwargs):
# construct the value combination of the arguments
# e.g., ldtype=["float32", "int32"], rdtype=["float16", "int16"]
# arg_combination = [
Expand All @@ -69,7 +70,7 @@ def __init__(self, func, name, target, auto_broadcast, dispatch, **kwargs):
self.name = name
self.target = target
self.auto_broadcast = auto_broadcast
self.dispatch = dispatch
self.dispatchable = 'fallback' in inspect.signature(self.func).parameters

def __call__(self, *args, **kwargs):
return self.func(*args, **kwargs)
Expand All @@ -79,7 +80,7 @@ def invoke_all(self):
if self.attrs_valid(**each_kwargs):
name = self.name \
+ ''.join(["{}_{}".format(key, each_kwargs[key]) for key in self.attrs])
if self.dispatch is False:
if self.dispatchable is False:
sch, args = self.func(**each_kwargs)
yield sch, args, name
else:
Expand All @@ -105,7 +106,7 @@ def get_op_name(self, name, args):

def get_config_spaces(self):
for each_kwargs in self.arg_combination:
if self.attrs_valid(**each_kwargs) and self.dispatch is True:
if self.attrs_valid(**each_kwargs) and self.dispatchable is True:
name = self.name \
+ ''.join(["{}_{}".format(key, each_kwargs[key]) for key in self.attrs])
config_space = autotvm.ConfigSpace()
Expand All @@ -120,7 +121,7 @@ def get_binds(self, args):
return None


def defop(name, target=None, auto_broadcast=False, dispatch=False, **kwargs):
def defop(name, target=None, auto_broadcast=False, **kwargs):
"""Decorator to define a tvm operator.
Parameters
----------
Expand All @@ -141,7 +142,7 @@ def defop(name, target=None, auto_broadcast=False, dispatch=False, **kwargs):
target = "cpu" if target is None else target

def _defop(func):
opdef = OpDef(func, name, target, auto_broadcast, dispatch, **kwargs)
opdef = OpDef(func, name, target, auto_broadcast, **kwargs)
__OP_DEF__.append(opdef)
return opdef
return _defop
Expand Down

0 comments on commit 2fe9f81

Please sign in to comment.