diff --git a/docs/api/python/hybrid.rst b/docs/api/python/hybrid.rst
new file mode 100644
index 000000000000..3b4c598d82dd
--- /dev/null
+++ b/docs/api/python/hybrid.rst
@@ -0,0 +1,15 @@
+tvm.hybrid
+----------
+.. automodule:: tvm.hybrid
+
+.. autosummary::
+
+   tvm.hybrid.parse
+   tvm.hybrid.script
+   tvm.hybrid.popcount
+   tvm.hybrid.sigmoid
+
+.. autofunction:: tvm.hybrid.parse
+.. autofunction:: tvm.hybrid.script
+.. autofunction:: tvm.hybrid.popcount
+.. autofunction:: tvm.hybrid.sigmoid
diff --git a/docs/api/python/index.rst b/docs/api/python/index.rst
index a6bed557dd3b..bab29b82f473 100644
--- a/docs/api/python/index.rst
+++ b/docs/api/python/index.rst
@@ -21,3 +21,4 @@ Python API
    dev
    topi
    nnvm/index
+   hybrid
diff --git a/docs/dev/hybrid_script.rst b/docs/dev/hybrid_script.rst
new file mode 100644
index 000000000000..0af02a56e72c
--- /dev/null
+++ b/docs/dev/hybrid_script.rst
@@ -0,0 +1,76 @@
+Hybrid Frontend Developer Guide
+===============================
+
+If you are a developer:
+
+1. who is trying to write some preliminary patterns that have not been supported by TVM yet,
+maybe :ref:`hybrid-langref-label` is a better place for you.
+
+2. who wants to know the implementation details of this module, you are in the right place!
+
+Features
+--------
+
+Software emulation
+~~~~~~~~~~~~~~~~~~
+
+In software emulation, the most interesting part is the decorator ``tvm.hybrid.script``.
+This decorator does two things:
+
+1. Importing the runtime variables
+
+2. Overloading the function according to the arguments passed
+
+The way 1. is implemented is admittedly fragile, but there is no better choice so far: the
+required names are added to the Python dict ``func.__globals__``, and after the call to
+``func`` is done, those names are cleaned up.
+
+Overloading is simple: the decorator checks the arguments' types and determines which function
+should actually be called.
+
+
+Backend Compilation
+~~~~~~~~~~~~~~~~~~~
+
+Compilation is a large module; see ``python/tvm/hybrid/var_decl.py`` and
+``python/tvm/hybrid/parser.py`` for more details. The first stage determines the
+usage, or more accurately the declaration, of each variable, and the second stage does
+the actual IR generation.
+
+Attributes
+~~~~~~~~~~
+
+So far, ONLY tensors' ``shape`` attribute is supported. You can see ``visit_Subscript``
+in ``python/tvm/hybrid/parser.py`` for more details. This is a hacky solution: the
+attribute is only checked when visiting a subscript.
+
+Loops
+~~~~~
+
+In HalideIR, loops have in total 4 types: ``serial``, ``unrolled``, ``parallel``, and ``vectorized``.
+
+
+.. note::
+
+    Unlike in HalideIR, in ``loop_type(a, b)``, ``a`` is the starting point and ``b``
+    is the end of the iteration; that is, ``loop_type(a, b)`` indicates the range ``[a, b)``.
+    Thus, when lowering it to HalideIR, we need to do ``start, extent = a, b - a``.
+
+
+.. note::
+
+    In HalideIR these loop types are enums and are named in passive form.
+    Here we use the active form to annotate loops, because they are ready to run.
+
+
+Variables
+~~~~~~~~~
+
+Because there are no variables in ``HalideIR``, every mutable variable will be lowered to an array with size 1.
+The first store of a variable is taken as its declaration.
+
+Math intrinsics
+~~~~~~~~~~~~~~~
+So far, the math intrinsics ``log``, ``exp``, ``sigmoid``, ``tanh``, ``power``, and ``popcount`` are supported.
+Math intrinsics will be imported by the decorator. Most of them are borrowed from library implementations,
+except ``popcount`` and ``sigmoid``, which are implemented manually.
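To make the software emulation mechanism above more concrete, here is a minimal, self-contained sketch of the
name-injection trick. It is only an illustration: the helper names and the tiny ``HYBRID_GLOBALS`` stand-in below
are hypothetical, not the actual ``tvm.hybrid`` implementation, and the real decorator additionally overloads on
the argument types and invokes the compiler when TVM types are passed.

.. code-block:: python

    import functools

    # Illustrative stand-in for the real intrinsic table in python/tvm/hybrid/intrin.py.
    HYBRID_GLOBALS = {'unroll': range, 'parallel': range}

    def script(func):
        @functools.wraps(func)
        def wrapper(*args):
            _globals = func.__globals__
            # Remember anything about to be shadowed, then inject the runtime names.
            saved = {k: _globals[k] for k in HYBRID_GLOBALS if k in _globals}
            _globals.update(HYBRID_GLOBALS)
            try:
                return func(*args)            # run the pure-Python emulation
            finally:
                for k in HYBRID_GLOBALS:      # clean the injected names up afterwards
                    _globals.pop(k, None)
                _globals.update(saved)        # restore whatever was shadowed
        return wrapper

    @script
    def fill(a):
        # 'unroll' resolves only while the wrapper runs; no import is needed here.
        for i in unroll(len(a)):
            a[i] = i

    buf = [0] * 4
    fill(buf)   # buf is now [0, 1, 2, 3]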
diff --git a/docs/dev/index.rst b/docs/dev/index.rst
index 3fb052938689..f3ab322bfe53 100644
--- a/docs/dev/index.rst
+++ b/docs/dev/index.rst
@@ -10,3 +10,4 @@ In this part of documentation, we share the rationale for the specific choices m
    runtime
    nnvm_json_spec
    nnvm_overview
+   hybrid_script
diff --git a/docs/langref/hybrid_script.rst b/docs/langref/hybrid_script.rst
new file mode 100644
index 000000000000..fdaed2b5be40
--- /dev/null
+++ b/docs/langref/hybrid_script.rst
@@ -0,0 +1,172 @@
+.. _hybrid-langref-label:
+
+Hybrid Frontend Language Reference
+==================================
+
+Overview
+--------
+
+This hybrid frontend allows users to write preliminary versions of some idioms that have not
+yet been officially supported by TVM.
+
+Features
+--------
+
+Software Emulation
+~~~~~~~~~~~~~~~~~~
+
+Both software emulation and compilation are supported. To define a function,
+you need to use the ``tvm.hybrid.script`` decorator to indicate that it is a hybrid function:
+
+.. code-block:: python
+
+    @tvm.hybrid.script
+    def outer_product(a, b, c):
+        for i in range(a.shape[0]):
+            for j in range(b.shape[0]):
+                c[i, j] = a[i] * b[j]
+    a = numpy.random.rand(100)
+    b = numpy.random.rand(99)
+    c = numpy.zeros((100, 99))
+    outer_product(a, b, c)
+
+This decorator will automatically import the required `Keywords`_ during software emulation.
+After software emulation is done, the imported keywords will be cleaned up. Users do not need
+to worry about keyword conflicts or pollution.
+
+Every element passed in the argument list for software emulation is either a Python variable
+or a ``numpy`` numeric type.
+
+Backend Compilation
+~~~~~~~~~~~~~~~~~~~
+
+The current parse interface looks like:
+
+.. code-block:: python
+
+    a = tvm.placeholder((100, ), name='a')
+    b = tvm.placeholder((99, ), name='b')
+    c = tvm.placeholder((100, 99), name='c')
+    tvm.hybrid.parse(outer_product, [a, b, c]) # returns the ir root of this function
+
+If we pass these TVM tensors to this function, it returns an op node:
+
+**Under construction, we are still deciding what kind of node should be returned.**
+
+.. code-block:: python
+
+    a = tvm.placeholder((100, ), name='a')
+    b = tvm.placeholder((99, ), name='b')
+    c = tvm.placeholder((100, 99), name='c')
+    op = outer_product(a, b, c) # returns the corresponding op node
+
+Tuning
+~~~~~~
+
+**Under construction, not truly supported yet.**
+
+Following the example above, you can use some TVM-like interfaces to tune the code:
+
+.. code-block:: python
+
+    sch = tvm.create_schedule(op)
+    jo, ji = sch.split(j, 4)
+    sch.vectorize(ji)
+
+``split``, ``reorder``, and loop annotation will be supported!
+
+Loops
+~~~~~
+
+In HalideIR, loops have in total 4 types: ``serial``, ``unrolled``, ``parallel``, and ``vectorized``.
+
+Here we use these **4** keywords, ``range`` (aka ``serial``), ``unroll``, ``parallel``, and ``vectorize``,
+to annotate the corresponding types of for loops.
+The usage is roughly the same as Python's standard ``range``.
+
+Variables
+~~~~~~~~~
+
+Every mutable variable will be lowered to an array with size 1.
+The first store of a variable is regarded as its declaration.
+
+.. note::
+
+    Unlike conventional Python, in hybrid script, a declared variable
+    can only be used in the scope level where it is declared.
+
+
+.. note::
+
+    Currently, you can ONLY use basic-typed variables, i.e. the type of the
+    variable should be either ``float32`` or ``int32``.
+
+.. code-block:: python
+
+    for i in range(5):
+        s = 0 # declaration, this s will be a 1-array in lowered IR
+        for j in range(5):
+            s += a[i, j] # do something with s
+        b[i] = s # you can still use s in this level
+    a[0] = s # you CANNOT use s here, even though it is allowed in conventional Python
+    b = (1, 2) # this has NOT been supported yet!
+
+
+Attributes
+~~~~~~~~~~
+
+So far, ONLY tensors' ``shape`` attribute is supported! The ``shape`` attribute is essentially a
+tuple, so you MUST access it as an array. Also, currently, only constant-indexed access is supported.
+
+.. code-block:: python
+
+    x = a.shape[2] # OK!
+    for i in range(3):
+        for j in range(a.shape[i]): # BAD! i is not a constant!
+            # do something
+
+
+Conditional Statement and Expression
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. code-block:: python
+
+    if condition:
+        # do something
+
+    a = b if condition else c
+
+However, the ``True`` and ``False`` keywords are not supported yet.
+
+
+Math Intrinsics
+~~~~~~~~~~~~~~~
+
+So far, the math intrinsics ``log``, ``exp``, ``sigmoid``,
+``tanh``, ``power``, and ``popcount`` are supported.
+As mentioned in `Software Emulation`_, no import is required; just use them!
+
+Array Allocation
+~~~~~~~~~~~~~~~~
+
+**Under construction, this function will be supported later!**
+
+Use a function call ``allocate(shape, type, share/local)`` to declare an array buffer.
+The basic usage is roughly the same as a normal array.
+
+
+Thread Bind
+~~~~~~~~~~~
+
+
+You can also do loop-thread binding by writing code like this:
+
+.. code-block:: python
+
+    for tx in bind("threadIdx.x", 100):
+        a[tx] = b[tx]
+
+
+Keywords
+~~~~~~~~
+- For keywords: ``serial``, ``range``, ``unroll``, ``parallel``, ``vectorize``, ``bind``
+- Math keywords: ``log``, ``exp``, ``sigmoid``, ``tanh``, ``power``, ``popcount``
diff --git a/docs/langref/index.rst b/docs/langref/index.rst
index dc51c3172c57..65f78d1d278b 100644
--- a/docs/langref/index.rst
+++ b/docs/langref/index.rst
@@ -2,3 +2,8 @@ Language Reference
 ==================
 
 This document provide references to embedded languages in TVM stack.
+
+.. toctree::
+   :maxdepth: 2
+
+   hybrid_script
diff --git a/python/tvm/build_module.py b/python/tvm/build_module.py
index 72b89af020d7..777654af6619 100755
--- a/python/tvm/build_module.py
+++ b/python/tvm/build_module.py
@@ -332,12 +332,20 @@ def lower(sch,
     lower_phase1 = [x[1] for x in add_lower_pass if x[0] == 1]
     lower_phase2 = [x[1] for x in add_lower_pass if x[0] == 2]
     lower_phase3 = [x[1] for x in add_lower_pass if x[0] > 2]
-    # normalize schedule first
-    sch = sch.normalize()
+    # Phase 0
-    bounds = schedule.InferBound(sch)
-    stmt = schedule.ScheduleOps(sch, bounds)
-    stmt = ir_pass.InjectPrefetch(stmt)
+    if isinstance(sch, schedule.Schedule):
+        # normalize schedule first
+        sch = sch.normalize()
+        bounds = schedule.InferBound(sch)
+        stmt = schedule.ScheduleOps(sch, bounds)
+        stmt = ir_pass.InjectPrefetch(stmt)
+    else:
+        #So far there is no op for hybrid script, so a plain ir body is given
+        if not isinstance(sch, _stmt.Stmt):
+            raise ValueError("sch should be either a Schedule or a Stmt")
+        stmt = sch
+
     for f in lower_phase0:
         stmt = f(stmt)
     # Phase 1
diff --git a/python/tvm/hybrid/__init__.py b/python/tvm/hybrid/__init__.py
new file mode 100644
index 000000000000..e0a39c562f0f
--- /dev/null
+++ b/python/tvm/hybrid/__init__.py
@@ -0,0 +1,10 @@
+"""Hybrid Programming APIs of TVM Python Package.
+
+This package maps a subset of python to HalideIR so that:
+1.
Users can write some preliminary versions of the computation patterns +have not been supported yet and verify it across the real execution and +python semantic emulation. +2. Developers can build HalideIR by writing Python code. +""" + +from .api import script, parse diff --git a/python/tvm/hybrid/api.py b/python/tvm/hybrid/api.py new file mode 100644 index 000000000000..bc5376509522 --- /dev/null +++ b/python/tvm/hybrid/api.py @@ -0,0 +1,46 @@ +"""APIs of lowering the Python subset to HalideIR""" +from __future__ import absolute_import as _abs + +import types +import decorator +from .parser import parse_python + +@decorator.decorator +def script(func, *args): + """If the arguments are tvm types, compile it to HalideIR. + O.W. return the python emulated result""" + from .util import _enter_hybrid_runtime, _restore_runtime, _is_tvm_arg_types + if _is_tvm_arg_types(args): + return parse(func, args) + else: + intersect = _enter_hybrid_runtime(func) + func(*args) + _restore_runtime(func, intersect) + return func + + +def parse(func, args): + """Parse a subset of Python to HalideIR + + Parameters + ---------- + func : str or types.FunctionType + If it is a string, parse the source code + If it is a function, parse the function + + args : list of Buffer or Tensor or Var + The argument lists to the function. + Leave it None if no buffer is related to the function to be parsed + + Returns + ------- + root : Stmt + The result Halide IR and the parser class instance. + """ + from .util import _pruned_source + if isinstance(func, str): + src = func + else: + assert isinstance(func, types.FunctionType) + src = _pruned_source(func) + return parse_python(src, args) diff --git a/python/tvm/hybrid/intrin.py b/python/tvm/hybrid/intrin.py new file mode 100644 index 000000000000..93517fef4d1d --- /dev/null +++ b/python/tvm/hybrid/intrin.py @@ -0,0 +1,112 @@ +"""Intrinsics of TVM-Python Hybrid Script for Python runtime""" + +import numpy +from ..stmt import For + +class _range(object): + """Base class of the loop ranges in hybrid script""" + def __init__(self, a, b=None): + if b is None: + self.low = 0 + self.ext = a + else: + self.low = a + self.ext = b + + def __iter__(self): + i = 0 + while i < self.ext: + yield i + self.low + i += 1 + + +class bind(_range): #pylint: disable=invalid-name + def __init__(self, tag, ext): + super(bind, self).__init__(ext) + self.tag = tag + + +unroll = vectorize = parallel = _range #pylint: disable=invalid-name + + +def allocate(shape, dtype='float32'): + """Allocate a buffer with given shape + + Parameters + ---------- + shape: Tuple + The shape of the tensor to be allocated + dtype: string + The data type of the tensor + + Returns + ------- + tensor: numpy.array + The tensor allocated + """ + return numpy.zeros(shape).astype(dtype) + + +def popcount(x): + """ + Count ones in the binary representation of number x + + Parameters + ---------- + x: Integer + The number to be counted + + Returns + ------- + cnt: Integer + The number of ones in the binary representation of number x + """ + cnt = 0 + while x: + x -= x & -x + cnt += 1 + return cnt + + +def sigmoid(x): + """ + Sigmoid function of x, aka 1/(1+exp(-x)). 
+ + Parameters + ---------- + x: a real number + + Returns + ------- + res: a real number + The result of sigmoid function + """ + return 1 / (1 + numpy.exp(-x)) + + +HYBRID_GLOBALS = { + 'unroll' : unroll, + 'vectorize' : vectorize, + 'parallel' : parallel, + 'allocate' : allocate, + 'bind' : bind, + 'sqrt' : numpy.sqrt, + 'log' : numpy.log, + 'tanh' : numpy.tanh, + 'power' : numpy.power, + 'exp' : numpy.exp, + 'sigmoid' : sigmoid, + 'popcount' : popcount +} + + +LOOP_INTRIN = { + 'range' : For.Serial, + 'unroll' : For.Unrolled, + 'parallel' : For.Parallel, + 'vectorize': For.Vectorized, + 'bind' : None +} + + +MATH_INTRIN = ['sqrt', 'log', 'exp', 'tanh', 'sigmoid', 'power', 'popcount'] diff --git a/python/tvm/hybrid/parser.py b/python/tvm/hybrid/parser.py new file mode 100644 index 000000000000..7d4c40e8c7e9 --- /dev/null +++ b/python/tvm/hybrid/parser.py @@ -0,0 +1,342 @@ +"""Hybrid Script Parser""" + +import ast +import operator +import sys +from .util import make_nop, make_const_true, make_range_one, halide_imm_types +from .intrin import LOOP_INTRIN, MATH_INTRIN +from .var_decl import determine_variable_usage +from ..api import thread_axis +from .. import expr as _expr +from .. import make as _make +from .. import intrin +from .. import api as _api +from .. import ir_pass as _ir_pass + +def list_to_block(visit, lst): + """Convert a list of Python IR nodes to HalideIR Block""" + lst = list(map(visit, lst)) + lst = [stmt for stmt in lst if not _ir_pass.Equal(stmt, make_nop())] + if not lst: + return make_nop() + if len(lst) == 1: + return lst[0] + body = lst[0] + for i in lst[1:]: + body = _make.Block(body, i) + return body + + +class HybridParser(ast.NodeVisitor): + """Python AST visitor pass which finally lowers it to HalideIR""" + + + _binop_maker = { + ast.Add : operator.add, + ast.Sub : operator.sub, + ast.Mult : operator.mul, + ast.Div : _make.Div, + ast.Mod : operator.mod, + ast.BitOr : operator.or_, + ast.BitAnd: operator.and_, + ast.BitXor: operator.xor, + ast.Gt : operator.gt, + ast.GtE : operator.ge, + ast.Lt : operator.lt, + ast.LtE : operator.le, + ast.Eq : operator.eq, + ast.NotEq : operator.ne, + } + + + _unaryop_maker = { + ast.USub : operator.neg, + ast.Invert : operator.invert, + ast.Not : operator.not_ + } + + + def __init__(self, args, usage, func_name=None): + """ + Parameters + ---------- + args: A list of tvm.placeholder or tvm.var + Provided by the user, the argument list of the function to be lowered. 
+ + usage: A dict of variables used in last in this function + Provided by last lower pass, which collects this information + + Returns + ------- + func_name: str + The name of the function to be lowered; if not provided, + the compiler will use the name in the AST + """ + self.args = args[:] + self.usage = usage.copy() + self._args = {} # Dict maps arg name to actual arg instance (either a var or a buffer) + self.buffers = {} + self.loops_above = {} # State variable that indicates loop levels above the current node + self.var_consts = {} # Variables that are determined as readonly in previous stage + self.func_name = func_name # The name of the function to be lowered + self.iter_axis = [] + + + def wrap_up_realize(self, node, body): + """Wrap up all the variables which will no longer be used""" + for key, val in self.usage.items(): + if key in self.var_consts.keys(): + continue + _, scope, _ = val + if scope == node: + _buf = self.buffers[key] + _dtype = _buf.dtype + _one = make_range_one() + _true = make_const_true() + body = _make.Realize(_buf.op, 0, _dtype, [_one], _true, body) + return body + + + def _check_id_a_buffer(self, s): + if s not in self._args.keys(): + raise ValueError("This %s is expected to be in argument list or allocated buffer!" % s) + + + #pylint: disable=invalid-name, missing-docstring + def visit_Module(self, node): + if len(node.body) != 1: + raise ValueError("Only one-function source code can be fed to this parser!") + return self.visit(node.body[0]) + + + def visit_FunctionDef(self, node): + if len(node.args.args) != len(self.args): + raise ValueError("The number of arguments passed to the function\ + should be the same as it is defined!") + for idx, arg in enumerate(node.args.args): + _attr = 'id' if sys.version_info[0] < 3 else 'arg' # To make py2 and 3 compatible + self._args[getattr(arg, _attr)] = self.args[idx] + res = list_to_block(self.visit, node.body) + res = self.wrap_up_realize(node, res) + if self.func_name is None: + self.func_name = node.name + return res + + + def visit_Expr(self, node): + return self.visit(node.value) + + + def visit_Name(self, node): + _id = node.id + if _id in self._args.keys() and isinstance(self._args[_id], _expr.Var): + return self._args[_id] + elif _id in self.loops_above.keys(): + return self.loops_above[_id] + if _id in self._args.keys(): + raise ValueError("This id %s should be handled in visit_Subscript!" % _id) + if _id not in self.usage.keys(): + raise ValueError("This id %s is expected to be a defined variable!" % _id) + # Buffer + if _id in self.buffers.keys(): + _buf = self.buffers[_id] + return _make.Call(_buf.dtype, _id, [_api.const(0)], _expr.Call.Halide, _buf.op, 0) + # Compilation time constant + if _id not in self.var_consts.keys(): + raise ValueError("This id %s is expected to a compilation time constant!" 
% _id) + return self.var_consts[_id] + + + def visit_Num(self, node): + return _api.const(node.n) + + + def visit_Assign(self, node): + if len(node.targets) != 1: + raise ValueError("So far only one-valued assignment is supported!") + lhs = node.targets[0] + rhs = _ir_pass.Simplify(self.visit(node.value)) + if isinstance(lhs, ast.Name): + #TODO: support defined intermediate buffer later + lhs_ = lhs + lhs = lhs.id + if lhs in self.loops_above.keys(): + raise ValueError("You CAN NEVER overwrite a loop variable!") + decl, _, rw = self.usage[lhs] + if decl == lhs_: + if lhs in self.var_consts.keys(): + raise ValueError("BUG: A constant cannot be overwritten!") + if lhs in self.buffers.keys(): + raise ValueError("BUG: This value should not be defined before this point!") + if isinstance(rhs, halide_imm_types) and ast.Store not in rw: + self.var_consts[lhs] = rhs + else: + self.buffers[lhs] = _api.placeholder((1, ), dtype=rhs.dtype, name=lhs) + if lhs in self.var_consts.keys(): + return make_nop() + else: + if lhs not in self.buffers.keys(): + raise ValueError("BUG: This value should be defined before!") + return _make.Provide(self.buffers[lhs].op, 0, rhs, [_api.const(0, dtype=rhs.dtype)]) + else: + lhs = self.visit(lhs) + if not isinstance(lhs, _expr.Call): + raise ValueError("An array access's LHS is expected to be a expr.Call!") + #TODO: support slice later + self._check_id_a_buffer(lhs.name) + return _make.Provide(self._args[lhs.name].op, 0, rhs, lhs.args) + + + def visit_Index(self, node): + if isinstance(node.value, ast.Tuple): + return [self.visit(i) for i in node.value.elts] + return [self.visit(node.value)] + + + def visit_Subscript(self, node): + args = self.visit(node.slice) + if isinstance(node.value, ast.Name): + array = node.value.id + self._check_id_a_buffer(array) + _buf = self._args[array] + return _make.Call(_buf.dtype, array, args, _expr.Call.Halide, _buf.op, 0) + elif isinstance(node.value, ast.Attribute): + if not isinstance(node.value.value, ast.Name): + raise ValueError("The root of array access is expect to be a id!") + if node.value.attr != "shape": + raise ValueError("Attribute access so far only 'shape' is supported!") + if len(args) != 1: + raise ValueError("For 'shape' access the argument should be only one!") + args = args[0] + #TODO: maybe support non-constant value later? 
+ if not isinstance(args, (_expr.IntImm, _expr.UIntImm)): + raise ValueError("So far only constant shape access supported!") + self._check_id_a_buffer(node.value.value.id) + return self._args[node.value.value.id].shape[args.value] + else: + raise ValueError("Not supported yet!") + + + def visit_With(self, node): + if sys.version_info[0] < 3: + context = node.context_expr + option = node.optional_vars + else: + if len(node.items) != 1: + raise ValueError("Only one with element is supported so far!") + context = node.items[0].context_expr + option = node.items[0].optional_vars + if not isinstance(context, ast.Call): + raise ValueError("The object must be a Python function call!") + if not isinstance(option, ast.Name): + raise ValueError("The object after 'as' must be an id!") + self.annotation[option.id] = context.func.id + return list_to_block(self.visit, node.body) + + + def visit_If(self, node): + cond = self.visit(node.test) + if_body = list_to_block(self.visit, node.body) + if node.orelse: + else_body = list_to_block(self.visit, node.orelse) + else: + else_body = make_nop() + return _make.IfThenElse(cond, if_body, else_body) + + + def visit_IfExp(self, node): + cond = self.visit(node.test) + if_body = self.visit(node.body) + else_body = self.visit(node.orelse) + return _make.Select(cond, if_body, else_body) + + + def visit_Compare(self, node): + lhs = self.visit(node.left) + if len(node.ops) != 1: + raise ValueError("Only one compare op is supported!") + if len(node.comparators) != 1: + raise ValueError("Only one comparator is supported!") + rhs = self.visit(node.comparators[0]) + return HybridParser._binop_maker[type(node.ops[0])](lhs, rhs) + + + def visit_UnaryOp(self, node): + operand = self.visit(node.operand) + return HybridParser._unaryop_maker[type(node.op)](operand) + + + def visit_BinOp(self, node): + lhs = self.visit(node.left) + rhs = self.visit(node.right) + return HybridParser._binop_maker[type(node.op)](lhs, rhs) + + + def visit_Call(self, node): + # Yet, no function pointer supported + if not isinstance(node.func, ast.Name): + raise ValueError("Only id-function function call is supported so far!") + func_id = node.func.id + n = len(node.args) + if func_id in LOOP_INTRIN.keys() and func_id != 'bind': + if n == 1: + low, ext = _api.const(0, dtype='int32'), self.visit(node.args[0]) + else: + if n != 2: + raise ValueError("A loop intrinsic should only have 1 or 2 arguments!") + low, ext = self.visit(node.args[0]), self.visit(node.args[1]) + if not _ir_pass.Equal(low, _api.const(0, dtype='int32')): + ext = ext - low + for_type = LOOP_INTRIN[func_id] + iter_var = None + return iter_var, low, ext, for_type + elif func_id == 'bind': + if n != 2: + raise ValueError("A loop bind should only have 2 arguments!") + if not isinstance(node.args[0], ast.Str): + raise ValueError("A loop bind's first argument should be a string!") + _vn = node.args[0].s + iter_var = thread_axis(node.args[0].s) + low, ext = _api.const(0, dtype='int32'), self.visit(node.args[1]) + for_type = None + return iter_var, low, ext, for_type + elif func_id in MATH_INTRIN: + return getattr(intrin, func_id)(*[self.visit(arg) for arg in node.args]) + elif func_id == 'allocate': + #TODO: Support it later! 
+ return make_nop() + else: + raise ValueError("Function call not supported yet!") + + + def visit_For(self, node): + iter_var, low, ext, for_type = self.visit(node.iter) + if not isinstance(node.target, ast.Name): + raise ValueError("The loop iterator should be a variable!") + _name = node.target.id + if iter_var is None: + if for_type is None: + raise ValueError("The loop bind function parse error!") + iter_var = _api.var(_name) + self.loops_above[_name] = iter_var + else: + if for_type is not None: + raise ValueError("The loop iterating function parse error!") + self.loops_above[_name] = iter_var.var + _body = list_to_block(self.visit, node.body) + _body = self.wrap_up_realize(node, _body) + if for_type is None: + res = _make.AttrStmt(iter_var, 'thread_extent', ext, _body) + else: + res = _make.For(iter_var, low, ext, for_type, 0, _body) + self.loops_above.pop(_name) + return res + + +def parse_python(src, args): + """The helper function of calling the AST visitor""" + root = ast.parse(src) + var_usage = determine_variable_usage(root, args) + parser = HybridParser(args, var_usage) + halide_ir = parser.visit(root) + return halide_ir diff --git a/python/tvm/hybrid/util.py b/python/tvm/hybrid/util.py new file mode 100644 index 000000000000..8a5f4a62768d --- /dev/null +++ b/python/tvm/hybrid/util.py @@ -0,0 +1,76 @@ +"""Internal utilities for parsing Python subset to HalideIR""" + +import inspect +import numpy +from .intrin import HYBRID_GLOBALS +from .._ffi.base import numeric_types +from .. import api as _api +from .. import make as _make +from .. import expr as _expr +from ..tensor import Tensor + + +#pylint: disable=invalid-name +np_arg_types = tuple(list(numeric_types) + [numpy.ndarray]) +tvm_arg_types = (Tensor, _expr.Var) +halide_imm_types = (_expr.IntImm, _expr.FloatImm, _expr.UIntImm) + + +# Useful constants. In avoid of runtime dependences, we use function calls to return them. +def make_nop(): + """Returns a 'no operation' node in HalideIR.""" + return _make.Evaluate(_api.const(0, dtype='int32')) + + +def make_range_one(): + """Returns a [0, 1] range node in HalideIR.""" + return _make.range_by_min_extent(0, 1) + + +def make_const_true(): + """Returns a constant True node in HalideIR.""" + return _api.convert(True) + + +def _pruned_source(func): + """Prune source code's extra leading spaces""" + lines = inspect.getsource(func).split('\n') + leading_space = len(lines[0]) - len(lines[0].lstrip(' ')) + lines = [line[leading_space:] for line in lines] + return '\n'.join(lines) + + +def _is_tvm_arg_types(args): + """Determine a list of element is either a list of tvm arguments of a list of numpy arguments. + If neither is true, raise a value error.""" + if isinstance(args[0], tvm_arg_types): + for elem in args[1:]: + if not isinstance(elem, tvm_arg_types): + raise ValueError("Expect a Var or Tensor instance but % get!" % str(type(elem))) + return True + if not isinstance(args[0], np_arg_types): + raise ValueError("Expect a numpy type but % get!" % str(type(args[0]))) + for elem in args[1:]: + if not isinstance(elem, np_arg_types): + raise ValueError("Expect a numpy type but % get!" 
% str(type(elem))) + return False + + +def _enter_hybrid_runtime(func): + """Put hybrid runtime variables into the global scope""" + _globals = func.__globals__ + intersect = [] + for elem in list(HYBRID_GLOBALS.keys()): + if elem in _globals.keys(): + intersect.append((elem, _globals[elem])) + _globals[elem] = HYBRID_GLOBALS[elem] + return intersect + + +def _restore_runtime(func, intersect): + """Rollback the modification caused by hybrid runtime""" + _globals = func.__globals__ + for elem in list(HYBRID_GLOBALS.keys()): + _globals.pop(elem) + for k, v in intersect: + _globals[k] = v diff --git a/python/tvm/hybrid/var_decl.py b/python/tvm/hybrid/var_decl.py new file mode 100644 index 000000000000..940b8c088df3 --- /dev/null +++ b/python/tvm/hybrid/var_decl.py @@ -0,0 +1,76 @@ +"""Determines the declaration, r/w status, and last use of each variable""" + +import ast +import sys +from .intrin import HYBRID_GLOBALS + + +class PyVariableUsage(ast.NodeVisitor): + """The vistor class to determine the declaration, r/w status, and last use of each variable""" + #pylint: disable=invalid-name + #pylint: disable=missing-docstring + def __init__(self, args): + self.status = {} + self.scope_level = [] + self._args = {} + self.args = args + + + def visit_FunctionDef(self, node): + self.scope_level.append(node) + if len(node.args.args) != len(self.args): + raise ValueError('#arguments passed should be the same as #arguments defined') + for idx, arg in enumerate(node.args.args): + _attr = 'id' if sys.version_info[0] < 3 else 'arg' # To make py2 and 3 compatible + self._args[getattr(arg, _attr)] = self.args[idx] + for i in node.body: + self.visit(i) + + + def visit_For(self, node): + if not isinstance(node.target, ast.Name): + raise ValueError("For's iterator should be an id") + self.visit(node.iter) + self.scope_level.append(node) + for i in node.body: + self.visit(i) + self.scope_level.pop() + + + def visit_Call(self, node): + #No function pointer supported so far + if not isinstance(node.func, ast.Name): + raise ValueError("Function call should be an id") + if (node.func.id not in HYBRID_GLOBALS.keys()) and node.func.id != 'range': + raise ValueError("Function call id not in intrinsics' list") + for elem in node.args: + self.visit(elem) + + + def visit_Name(self, node): + # If it is from the argument list or loop variable, we do not worry about it! 
+ if node.id in self._args.keys(): + return + fors = [loop.target.id for loop in self.scope_level if isinstance(loop, ast.For)] + if node.id in fors: + return + # The loop variable cannot be overwritten when iteration + if isinstance(node.ctx, ast.Store) and node.id in fors: + raise ValueError("Iter var cannot be overwritten") + + if node.id not in self.status.keys(): + if not isinstance(node.ctx, ast.Store): + raise ValueError('In Python, "first store" indicates "declaration"') + self.status[node.id] = (node, self.scope_level[-1], set()) + else: + decl, loop, usage = self.status[node.id] + loop = self.scope_level[-1] + usage.add(type(node.ctx)) + self.status[node.id] = (decl, loop, usage) + + +def determine_variable_usage(root, args): + """The helper function for calling the dedicated visitor.""" + visitor = PyVariableUsage(args) + visitor.visit(root) + return visitor.status diff --git a/tests/python/unittest/test_hybrid_script.py b/tests/python/unittest/test_hybrid_script.py new file mode 100644 index 000000000000..fda4f52c1f19 --- /dev/null +++ b/tests/python/unittest/test_hybrid_script.py @@ -0,0 +1,287 @@ +import tvm, inspect, sys, traceback, numpy +from tvm.hybrid import script +from tvm.hybrid.intrin import HYBRID_GLOBALS + +@script +def outer_product(n, m, a, b, c): + for i in range(n): + for j in range(m): + c[i, j] = a[i] * b[j] + +#Test global function +#Test bridge between frontend and backend +def test_outer_product(): + n = tvm.var('n') + m = tvm.var('m') + a = tvm.placeholder((n, ), name='a') + b = tvm.placeholder((m, ), name='b') + c = tvm.placeholder((n, m), name='c') + ir = outer_product(n, m, a, b, c) + #Check for i in (0, n) + assert isinstance(ir, tvm.stmt.For) + assert ir.loop_var.name == 'i' + assert ir.min.value == 0 + assert ir.extent.name == 'n' + ibody = ir.body + assert isinstance(ibody, tvm.stmt.For) + #Check for j in (0, m) + assert ibody.loop_var.name == 'j' + assert ibody.min.value == 0 + assert ibody.extent.name == 'm' + #Check loop body + jbody = ibody.body + assert isinstance(jbody, tvm.stmt.Provide) + assert jbody.func.name == 'c' + assert len(jbody.args) == 2 + assert jbody.args[0].name == 'i' + assert jbody.args[1].name == 'j' + assert isinstance(jbody.value, tvm.expr.Mul) + mul = jbody.value + assert isinstance(mul.a, tvm.expr.Call) + assert mul.a.name == 'a' + assert mul.b.name == 'b' + + func = tvm.lower(ir, [n, m, a, b, c]) + func = tvm.build(func) + + _n = 999 + _m = 1001 + _a = numpy.random.rand(_n).astype('float32') + _b = numpy.random.rand(_m).astype('float32') + c_python = numpy.zeros((_n, _m), dtype='float32') + outer_product(_n, _m, _a, _b, c_python) + + tvm_a = tvm.ndarray.array(_a) + tvm_b = tvm.ndarray.array(_b) + tvm_c = tvm.ndarray.array(numpy.zeros((_n, _m), dtype='float32')) + func(_n, _m, tvm_a, tvm_b, tvm_c) + numpy.testing.assert_allclose(tvm_c.asnumpy(), c_python, rtol=1e-5) + for key, _ in HYBRID_GLOBALS.items(): + assert key not in globals().keys() + assert key not in outer_product.__globals__.keys() + +#Test local function +#Test allocation of local variable +def test_fanout(): + @script + def fanout(n, a, b): + three = 3.0 + for i in range(a.shape[0] - 3): + sigma = 0.0 + for j in range(3): + sigma = sigma + a[i + j] + sigma = sigma / three + b[i] = sigma + + n = tvm.var('n') + a = tvm.placeholder((n, ), name='a') + b = tvm.placeholder((n-3, ), name='b') + ir = fanout(n, a, b) + + #Check for i in (0, n-3) + assert isinstance(ir, tvm.stmt.For) + assert ir.loop_var.name == 'i' + assert ir.min.value == 0 + assert 
tvm.ir_pass.Equal(ir.extent, n - 3) + #Check loopbody + ibody = ir.body + assert isinstance(ibody, tvm.stmt.Realize) + assert ibody.bounds[0].min.value == 0 + assert ibody.bounds[0].extent.value == 1 + assert ibody.func.name == 'sigma' + #Check i loop body + rbody = ibody.body + assert isinstance(rbody.first, tvm.stmt.Provide) + assert rbody.first.func.name == 'sigma' + assert len(rbody.first.args) == 1 + assert rbody.first.args[0].value == 0 + #Check fanout loop + jloop = rbody.rest.first + assert jloop.loop_var.name == 'j' + assert jloop.min.value == 0 + assert jloop.extent.value == 3 + jbody = jloop.body + assert isinstance(jbody, tvm.stmt.Provide) + assert len(jbody.args) == 1 + assert jbody.args[0].value == 0 + assert jbody.func.name == 'sigma' + assert isinstance(jbody.value, tvm.expr.Add) + value = jbody.value + assert isinstance(value.a, tvm.expr.Call) + assert value.a.name == 'sigma' + assert len(value.a.args) == 1 + assert value.a.args[0].value == 0 + assert value.b.name == 'a' + assert len(value.b.args) == 1 + assert tvm.ir_pass.Equal(value.b.args[0], ir.loop_var + jloop.loop_var) + divide= rbody.rest.rest.first + assert isinstance(divide, tvm.stmt.Provide) + assert len(divide.args) == 1 + assert divide.args[0].value == 0 + value = divide.value + assert isinstance(value, tvm.expr.Mul) + assert value.a.name == 'sigma' + assert len(value.a.args) == 1 + assert value.a.args[0].value == 0 + assert abs(value.b.value - (1 / 3.0)) < 1e-5 + write = rbody.rest.rest.rest + assert isinstance(write, tvm.stmt.Provide) + assert write.func.name == 'b' + assert write.value.name == 'sigma' + assert len(write.value.args) == 1 + assert write.value.args[0].value == 0 + +@script +def failure(): + for i in range(1, 100): + i = 0 + +def test_failure(): + try: + tvm.hybrid.parse(failure, []) + except IOError as err: + assert sys.version_info[0] == 2 + print('[Warning] Python2 cannot do the failure case because "%s"' % str(err)) + except Exception as err: + assert str(err) == 'You CAN NEVER overwrite a loop variable!' + + +def test_looptype(): + @script + def looptype(a): + for i in parallel(6): + a[i] = i + for j in vectorize(6): + a[j] = j + for k in unroll(6): + a[k] = k + a = tvm.placeholder((6, ), name='a') + ir = looptype(a) + iloop = ir.first + jloop = ir.rest.first + kloop = ir.rest.rest + assert iloop.for_type == tvm.stmt.For.Parallel + assert jloop.for_type == tvm.stmt.For.Vectorized + assert kloop.for_type == tvm.stmt.For.Unrolled + +def test_if(): + @script + def if_then_else(a, b): + for i in range(10): + if i % 2 == 0: + a[i] = -1 + else: + a[i] = 1 + for i in unroll(10): + b[i] = -1 if i % 2 == 0 else 1 + + a = tvm.placeholder((10, ), dtype='int32', name='a') + b = tvm.placeholder((10, ), dtype='int32', name='b') + ir = if_then_else(a, b) + func = tvm.lower(ir, [a, b]) + func = tvm.build(func) + assert func + + _a = numpy.zeros((10, ), dtype = 'int32') + _b = numpy.zeros((10, ), dtype = 'int32') + if_then_else(_a, _b) + + tvm_a = tvm.ndarray.array(numpy.zeros((10, ), dtype='int32')) + tvm_b = tvm.ndarray.array(numpy.zeros((10, ), dtype='int32')) + func(tvm_a, tvm_b) + + numpy.testing.assert_allclose(tvm_a.asnumpy(), _a, rtol=1e-5) + numpy.testing.assert_allclose(tvm_b.asnumpy(), _b, rtol=1e-5) + numpy.testing.assert_allclose(tvm_a.asnumpy(), tvm_b.asnumpy(), rtol=1e-5) + +def test_bind(): + if not tvm.gpu(0).exist: + print('No GPU found! 
Skip this test!') + return + @script + def vec_add(a, b, c): + for tx in bind('threadIdx.x', 1000): + c[tx] = b[tx] + c[tx] + + a = tvm.placeholder((1000, ), dtype='float32', name='a') + b = tvm.placeholder((1000, ), dtype='float32', name='b') + c = tvm.placeholder((1000, ), dtype='float32', name='c') + ir = vec_add(a, b, c) + + func = tvm.lower(ir, [a, b, c]) + func = tvm.build(func, target = 'cuda') + + _a = numpy.random.rand(1000).astype('float32') + _b = numpy.random.rand(1000).astype('float32') + _c = numpy.zeros((1000, ), dtype = 'float32') + + + tvm_a = tvm.ndarray.array(_a, tvm.gpu(0)) + tvm_b = tvm.ndarray.array(_b, tvm.gpu(0)) + tvm_c = tvm.ndarray.array(_c, tvm.gpu(0)) + + func(tvm_a, tvm_b, tvm_c) + vec_add(_a, _b, _c) + + numpy.testing.assert_allclose(_c, tvm_c.asnumpy(), rtol=1e-5) + +def test_math_intrin(): + @script + def intrin_real(a): + a[0] = sqrt(a[0]) + a[1] = log(a[1]) + a[2] = exp(a[2]) + a[3] = sigmoid(a[3]) + a[4] = power(a[4], a[5]) + a[5] = tanh(a[5]) + + a6 = tvm.placeholder((6, ), dtype='float32', name='a') + ir = intrin_real(a6) + func = tvm.build(tvm.lower(ir, [a6])) + assert func + a = numpy.arange(2, 8).astype('float32') + tvm_a = tvm.ndarray.array(a) + func(tvm_a) + intrin_real(a) + numpy.testing.assert_allclose(a, tvm_a.asnumpy(), rtol=1e-5) + + @script + def intrin_int(a): + a[0] = popcount(a[0]) + + a1 = tvm.placeholder((1, ), dtype='int32') + ir = intrin_int(a1) + func = tvm.build(tvm.lower(ir, [a1])) + assert func + a = numpy.array([1234567890]).astype('int32') + tvm_a = tvm.ndarray.array(a) + intrin_int(a) + func(tvm_a) + assert tvm_a.asnumpy()[0] == a[0] + +def test_allocate_buffer(): + def blur(a): + for i in serail(32): + h_blur = allocate((4, 36)) + for j in serail(4): + for k in serail(36): + s = allocate((1, ), 'float32') + for dj in serail(4): + s[0] = s[0] + a[i, j + dj] + h_blur[j, k] = s[0] / 4. + for j in serail(32): + s = 0. + for di in serail(4): + s = s + h_blur[di, j] + h_blur[i, j] = s / 4. + + +if __name__ == "__main__": + test_outer_product() + test_fanout() + test_failure() + test_looptype() + test_if() + test_bind() + test_math_intrin() +