Merged

Changes from all commits (57 commits)
ae28f58
attempt of c->python
May 7, 2021
1258223
name change
May 10, 2021
f57ae0e
build update
May 12, 2021
8a4f40e
import fix
May 19, 2021
f7e8bbd
build working, import still needs work
May 20, 2021
e4267fb
returning null binds in driver_api.cc
May 21, 2021
fa322f6
tests pass woohoo00!
May 21, 2021
51d85fb
merge to recent tvm
May 21, 2021
8e30116
black'd _ffi_api.py
May 22, 2021
ca11f46
remove simple_mode arg from lower in build_module.py for now
May 24, 2021
b117478
attempt add simple_mode arg in lower c++
May 24, 2021
41640a1
lower now can take in schedule or IRModule
May 24, 2021
724db08
add simple_mode parameter in c++
May 24, 2021
47ea819
reformat for lint
May 24, 2021
c01ca1e
include header details for lint
May 24, 2021
49d6c7f
ast lhs and rhs not matching, refactoring driver_api.cc
May 25, 2021
5e43b18
Added user-defined passes, still failing some tests because python le…
electriclilies May 26, 2021
59534cb
tests are green
electriclilies May 27, 2021
5cf0f78
Split lower api into 3
electriclilies May 28, 2021
b10d6ad
got rid of legacy lower
electriclilies May 28, 2021
5d5c068
renamed lower funcs
electriclilies May 28, 2021
f2596c5
remove python get_binds and rename flags
electriclilies May 28, 2021
d6be36e
fix typo
electriclilies May 28, 2021
06532d3
clean up doc and formatting
electriclilies May 28, 2021
4d7a602
fix typos
electriclilies May 28, 2021
2ed255d
fix lint
electriclilies May 28, 2021
cec3456
fix calls to lower in build_module_test
electriclilies May 28, 2021
b92dfe3
change enable_loop_partition back to simple_mode for consistency with…
electriclilies May 28, 2021
0d69397
clang format
electriclilies May 28, 2021
30ad407
fix typo
electriclilies May 28, 2021
51bcd3f
retrigger
electriclilies May 28, 2021
eddec7b
fix calls to ffi lower
electriclilies May 28, 2021
1bb8b9d
Add get binds to the FFI
electriclilies Jun 1, 2021
e932400
Merge branch 'main' into lower_build
electriclilies Jun 1, 2021
4b8a529
fix black
electriclilies Jun 1, 2021
6115f3c
comment out SchedulePostProcRewriteForTensorCore
electriclilies Jun 1, 2021
b9c8cb6
DELETE schedule postproc rewrite for tensorcore
electriclilies Jun 1, 2021
09d7806
Fix call fo lower_primfunc
electriclilies Jun 1, 2021
8bfc97e
Clean up comments
electriclilies Jun 1, 2021
566de68
change return of ffi get_binds
electriclilies Jun 1, 2021
cc9458e
Fix off by1
electriclilies Jun 2, 2021
4471a01
Merge branch 'main' of https://github.com/apache/incubator-tvm into l…
electriclilies Jun 2, 2021
ca4683d
respond to tristan's comments
electriclilies Jun 7, 2021
00fcd7a
update driver_api.h docs
electriclilies Jun 7, 2021
7a2b404
Apply suggestions from code review
Jun 7, 2021
8988368
remove relay.backend.lower
electriclilies Jun 7, 2021
a460f7e
Merge branch 'lower_build' of github.com:CircleSpin/tvm into lower_build
electriclilies Jun 7, 2021
38a121b
Respond to feedback
electriclilies Jun 8, 2021
55e90cf
remove 2nd get binds impl
electriclilies Jun 8, 2021
d3e18c2
fix lint
electriclilies Jun 8, 2021
6e5a2ff
Update src/driver/driver_api.cc
Jun 8, 2021
4933a5e
clang format
electriclilies Jun 9, 2021
b42fb19
update doc
electriclilies Jun 9, 2021
4c7049d
Merge branch 'lower_build' of github.com:CircleSpin/tvm into lower_build
electriclilies Jun 9, 2021
e462dab
fix typo
electriclilies Jun 9, 2021
d067a67
black
electriclilies Jun 9, 2021
794c606
fix pass ctx
electriclilies Jun 10, 2021
60 changes: 56 additions & 4 deletions include/tvm/driver/driver_api.h
@@ -34,6 +34,7 @@
#include <tvm/support/with.h>
#include <tvm/target/target.h>
#include <tvm/te/schedule_pass.h>
#include <tvm/tir/function.h>

#include <string>
#include <unordered_map>
@@ -42,17 +43,68 @@
#include <vector>

namespace tvm {

/*!
* \brief Lower an IRModule (optimize it with the pass list defined in CreatePassList)
* \param mod The IRModule to lower
* \param simple_mode Disables the loop partition pass. Defaults to false.
* \return The result module.
*/
TVM_DLL IRModule LowerModule(IRModule mod, bool simple_mode = false);

/*!
* \brief Lower a primfunc and name (convert to IRModule, and optimize it with the pass list
* defined in CreatePassList)
* \param func The PrimFunc to lower
* \param name The name of the lowered function.
* \param simple_mode Disables the loop partition pass. Defaults to false.
* \return The result module.
*/
TVM_DLL IRModule LowerPrimFunc(tvm::tir::PrimFunc func, const std::string& name,
bool simple_mode = false);

/*!
* \brief Build an IRModule given a schedule, args and binds
* \param sch The schedule to lower.
* \brief Build an IRModule given a TE schedule, args and binds. This function also applies
* the lowering passes defined in CreatePassList.
* \param sch The TE schedule to lower.
* \param args The arguments to the function.
* \param name The name of the lowered function.
* \param binds Buffer assignments.
* \param simple_mode Disables the loop partition pass. Defaults to false.
Contributor:
Why is there a special flag to disable the loop partition pass? Shouldn't the existing PassContext infrastructure handle this?

Contributor:
I'm not completely sure. This logic was already in the code base, so I just duplicated it and made it more explicit. (For context, when I started this refactor, Jared said he wanted to get it in quickly, so I should just try to naively duplicate the existing logic, but he hasn't reviewed it yet, so I'm not sure what the timeline is now.)

I can try to remove it, but I'm not sure what the best way to go about this, since there are a few tests that call lower directly.

Contributor:
It is probably fine for now.
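(The PassContext route mentioned above would look roughly like the sketch below. This is an illustration only, assuming the pass is registered under the name "tir.LoopPartition"; it is not part of this PR.)

import tvm
from tvm import te

n = te.var("n")
A = te.placeholder((n,), name="A")
B = te.compute((n,), lambda i: A[i] + 1.0, name="B")
s = te.create_schedule(B.op)

# Disable loop partitioning through the pass infrastructure instead of simple_mode
with tvm.transform.PassContext(opt_level=3, disabled_pass=["tir.LoopPartition"]):
    mod = tvm.lower(s, [A, B], name="main")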

* \return The result module.
*/
TVM_DLL IRModule lower(te::Schedule sch, const Array<te::Tensor>& args, const std::string& name,
const std::unordered_map<te::Tensor, tir::Buffer>& binds);

TVM_DLL IRModule LowerSchedule(te::Schedule sch, const Array<te::Tensor>& args,
const std::string& name,
const std::unordered_map<te::Tensor, tir::Buffer>& binds,
bool simple_mode = false);

/*!
* \brief Build an IRModule given a TE schedule, args and binds. This function also applies
* the lowering passes defined in CreatePassList.
* \param sch The TE schedule to lower.
* \param args The arguments to the function (Array of Tensor, Buffer and Vars)
* \param name The name of the lowered function.
* \param binds Buffer assignments.
* \param simple_mode Disables the loop partition pass. Defaults to false.
* \return The result module.
*/
TVM_DLL IRModule LowerSchedule(te::Schedule sch, const Array<ObjectRef>& args,
const std::string& name,
const std::unordered_map<te::Tensor, tir::Buffer>& binds,
bool simple_mode = false);

/*!
* \brief Create an IRModule out of a TE Schedule. It does not apply lowering passes. If you want
* to apply lowering passes as well, use LowerSchedule.
* \param sch The schedule
* \param args The arguments to the function.
* \param name The name of the lowered function.
* \param binds Buffer assignments.
* \return The result module.
*/
IRModule ScheduleToModule(te::Schedule sch, const Array<ObjectRef>& args, const std::string& name,
Contributor:
This looks almost exactly the same as LowerSchedule.

Contributor:
Initially this was a function called form_irmodule in Python. I translated it to C++ and renamed it ScheduleToModule. Unfortunately, form_irmodule is called by some tests, so to preserve that behavior I had to split the functions apart again and register them separately in the FFI.
The difference between them is that ScheduleToModule just converts the schedule to a module that hasn't yet been lowered, whereas LowerSchedule converts the schedule into a module and then applies the passes.

Contributor:
Can you be a little more explicit about this in the documentation?

Contributor:
@tkonolige does this look good to you?

const std::unordered_map<te::Tensor, tir::Buffer>& binds);
/*!
* \brief Build a device and host module for a specific target from an IRModule.
* \param funcs The functions to be built.
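A minimal Python sketch of the distinction discussed in the thread above, using the wrappers this PR exposes (schedule_to_module and tvm.lower); treat it as an illustration of the intended behavior rather than part of the change:

import tvm
from tvm import te
from tvm.driver.build_module import schedule_to_module

n = te.var("n")
A = te.placeholder((n,), name="A")
B = te.compute((n,), lambda i: A[i] * 2.0, name="B")
s = te.create_schedule(B.op)

# ScheduleToModule: TE schedule -> IRModule, no lowering passes applied
unlowered = schedule_to_module(s, [A, B], name="main")

# LowerSchedule (reached through tvm.lower): the same conversion, followed by
# the pass pipeline from CreatePassList
lowered = tvm.lower(s, [A, B], name="main")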
12 changes: 0 additions & 12 deletions include/tvm/te/schedule_pass.h
@@ -88,18 +88,6 @@ bool VerifyCompactBuffer(const Stmt& stmt);
*/
Stmt ScheduleOps(Schedule s, Map<IterVar, Range> dom_map, bool debug_keep_trivial_loop);

/*!
* \brief Try to modify the AST generated by ScheduleOps to support TensorCore.
*
* \param stmt The stmt to be transformed.
* \param schedule The original schedule.
* \param extern_buffer Map specifies external
* buffer assignment of input and outputs.
* \return Transformed stmt.
*/
Stmt SchedulePostProcRewriteForTensorCore(Stmt stmt, Schedule schedule,
Map<Tensor, Buffer> extern_buffer);

/*!
* \brief Postprocessing the Stmt generated by ScheduleOps to create
* a PrimFunc that can then be used for further TIR optimizations.
2 changes: 1 addition & 1 deletion python/tvm/autotvm/feature.py
@@ -39,7 +39,7 @@ def ana_lower(sch, args, binds=None, simple_mode=True):
"""Do lower while keeping all axes in IR
i.e. Do not eliminate loop with extent of 1, do not vectorize, unroll or inject virtual threads
"""
binds, _ = build_module.get_binds(args, binds)
binds, _ = build_module.get_binds(args, compact=False, binds=binds)
sch = sch.normalize()
# Phase 0
bounds = schedule.InferBound(sch)
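This call-site change matters because get_binds takes compact as its second parameter: the old positional call get_binds(args, binds) would silently pass the binds dictionary as compact. A small sketch of the updated keyword-based call (hypothetical tensor, not from the PR):

from tvm import te
from tvm.driver import build_module

A = te.placeholder((16,), name="A")
binds, arg_list = build_module.get_binds([A], compact=False, binds=None)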
20 changes: 20 additions & 0 deletions python/tvm/driver/_ffi_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""FFI APIs for tvm.driver"""
import tvm._ffi

tvm._ffi._init_api("driver", __name__)
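_init_api looks up every packed function the C++ side registers under the "driver" prefix (e.g. TVM_REGISTER_GLOBAL("driver.lower_module")) and attaches it to this module; that is what lets build_module.py below call ffi.lower_module and friends. A rough sketch of direct usage, assuming the registration names implied by those calls:

import tvm
from tvm import te
from tvm.driver import _ffi_api

n = te.var("n")
A = te.placeholder((n,), name="A")
B = te.compute((n,), lambda i: A[i] + 1.0, name="B")
s = te.create_schedule(B.op)

# Positional signatures mirror the ffi.* calls in build_module.py
mod = _ffi_api.schedule_to_module(s, [A, B], "main", None)
lowered = _ffi_api.lower_module(mod, False)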
155 changes: 20 additions & 135 deletions python/tvm/driver/build_module.py
@@ -37,96 +37,58 @@
from tvm.tir.buffer import Buffer
from tvm.tir.expr import Var

from . import _ffi_api as ffi


def get_binds(args, compact=False, binds=None):
"""Internal function to get binds and arg_list given arguments.

Parameters
----------
args : list of Buffer or Tensor or Var
The argument lists to the function.

compact : bool
If the statement has already bound to a compact buffer.

binds : dict of :any:`Tensor` to :any:`Buffer`, optional
Dictionary that maps the Tensor to Buffer which specified the data layout
requirement of the function. By default, a new compact buffer is created
for each tensor in the argument.

Returns
-------
binds: dict
The bind specification

arg_list: list
The list of symbolic buffers of arguments.
"""
binds = {} if binds is None else binds.copy()
arg_list = []
for x in args:
if isinstance(x, tensor.Tensor):
any_dim = any(isinstance(i, tvm.tir.Var) for i in x.shape)
buffer_type = "auto_broadcast" if any_dim and not compact else ""
if x not in binds:
buf = tvm.tir.decl_buffer(
x.shape, dtype=x.dtype, name=x.name, buffer_type=buffer_type
)
binds[x] = buf
arg_list.append(buf)
else:
arg_list.append(binds[x])
elif isinstance(x, schedule.Buffer):
arg_list.append(x)
elif isinstance(x, tvm.tir.Var):
arg_list.append(x)
else:
raise ValueError("args must be Tensor, Buffer or Var")
binds, arg_list = ffi.get_binds(args, compact, binds)
return binds, arg_list


def form_irmodule(sch, args, name, binds):
def schedule_to_module(
sch: schedule.Schedule,
args: Optional[List[Union[Buffer, tensor.Tensor, Var]]] = None,
name: str = "main",
binds: Optional[Mapping[tensor.Tensor, Buffer]] = None,
) -> IRModule:
"""According to the given schedule, form a function.

Parameters
----------
sch : tvm.te.schedule.Schedule
The given scheduler to form the raw body

args : list of Buffer or Tensor or Var
The argument lists to the function.

name : str
The name of result function.

The name of result function, default name is "main"
binds : dict of :any:`Tensor` to :any:`Buffer`, optional
The binds information

Returns
-------
The body formed according to the given schedule
"""
# normalize schedule first
pass_ctx = PassContext.current()
sch = sch.normalize()
bounds = schedule.InferBound(sch)
stmt = schedule.ScheduleOps(sch, bounds)

compact = schedule.VerifyCompactBuffer(stmt)
binds, arg_list = get_binds(args, compact, binds)

stmt = schedule.SchedulePostProcRewriteForTensorCore(stmt, sch, binds)
func = schedule.SchedulePostProcToPrimFunc(arg_list, stmt, binds)

func = func.with_attr("global_symbol", name)

if pass_ctx.config.get("tir.noalias", True):
func = func.with_attr("tir.noalias", True)
return tvm.IRModule({name: func})
return ffi.schedule_to_module(sch, args, name, binds)


def lower(
inputs: Union[schedule.Schedule, PrimFunc, IRModule],
inp: Union[schedule.Schedule, PrimFunc, IRModule],
args: Optional[List[Union[Buffer, tensor.Tensor, Var]]] = None,
name: str = "main",
binds: Optional[Mapping[tensor.Tensor, Buffer]] = None,
@@ -136,7 +98,7 @@ def lower(

Parameters
----------
input : Union[schedule.Schedule, PrimFunc, IRModule]
inputs : Union[schedule.Schedule, PrimFunc, IRModule]
Contributor:
Why is the name inputs when it is a singular schedule, primfunc, or irmodule?

Contributor:
The word input is a Python built-in function, so you can't name variables input. Initially the parameter was named inputs in the function signature and input in the documentation; I just changed the documentation to match the function signature. I don't like that they were different. I agree that inputs is not an ideal name though, and I'm open to suggestions for other names.

Contributor:
Maybe use inp instead? I can't really think of a good name either.

Contributor:
Hi @electriclilies, I guess even if input is a built-in function in Python, it can still be a parameter name? For example, this code should work:

def f(input=2):
    print(input)

Please correct me if I am wrong. :)

Contributor:
The problem is actually that the linter doesn't like it -- I guess it just flags everything called input for some reason.
https://ci.tlcpack.ai/blue/organizations/jenkins/tvm/detail/PR-8110/14/pipeline

Contributor:
Thanks for the detail! :) It's a pylint warning, so you can disable it using # pylint: disable=redefined-builtin. One example is here: https://github.com/apache/tvm/blob/main/python/tvm/topi/image/dilation2d.py#L27.

Contributor:
Oh, thanks! Well, I just changed it to inp, so I think I'll leave it that way.

Contributor:
Sounds good!

The TE schedule or TensorIR PrimFunc/IRModule to be built

args : Optional[List[Union[Buffer, tensor.Tensor, Var]]]
@@ -160,90 +122,13 @@ def lower(
m : IRModule
The result IRModule
"""
# config setup
pass_ctx = PassContext.current()
instrument_bound_checkers = bool(pass_ctx.config.get("tir.instrument_bound_checkers", False))
disable_vectorize = bool(pass_ctx.config.get("tir.disable_vectorize", False))
add_lower_pass = pass_ctx.config.get("tir.add_lower_pass", [])

lower_phase0 = [x[1] for x in add_lower_pass if x[0] == 0]
lower_phase1 = [x[1] for x in add_lower_pass if x[0] == 1]
lower_phase2 = [x[1] for x in add_lower_pass if x[0] == 2]
lower_phase3 = [x[1] for x in add_lower_pass if x[0] > 2]

# Phase 0
pass_list = lower_phase0
is_legacy_te_schedule: bool = False

if isinstance(inputs, schedule.Schedule):
if args is None:
raise ValueError("args must be given for lowering from TE schedule")
mod = form_irmodule(inputs, args, name, binds)
is_legacy_te_schedule = True
elif isinstance(inputs, PrimFunc):
func = inputs.with_attr("global_symbol", name)
if pass_ctx.config.get("tir.noalias", True):
func = func.with_attr("tir.noalias", True)
mod = tvm.IRModule({name: func})
elif isinstance(inputs, IRModule):
mod = inputs
else:
raise TypeError(
f"tvm.lower expected te.Schedule, PrimFunc or IRModule, but got {type(inputs)}"
)

# Phase 1
if is_legacy_te_schedule:
pass_list += [
tvm.tir.transform.InjectPrefetch(),
tvm.tir.transform.StorageFlatten(64, instrument_bound_checkers),
]
else:
pass_list += [
tvm.tir.transform.LowerInitBlock(),
tvm.tir.transform.PlanAndUpdateBufferAllocationLocation(),
tvm.tir.transform.ConvertBlocksToOpaque(),
tvm.tir.transform.CompactBufferAllocation(),
tvm.tir.transform.FlattenBuffer(),
]
pass_list += [
tvm.tir.transform.BF16Legalize(),
tvm.tir.transform.NarrowDataType(32),
tvm.tir.transform.Simplify(),
]

pass_list += lower_phase1

# Phase 2
if not simple_mode:
pass_list += [(tvm.tir.transform.LoopPartition())]

pass_list += [
tvm.tir.transform.VectorizeLoop(not disable_vectorize),
tvm.tir.transform.InjectVirtualThread(),
tvm.tir.transform.InjectDoubleBuffer(),
tvm.tir.transform.StorageRewrite(),
tvm.tir.transform.UnrollLoop(),
]
pass_list += lower_phase2

# Phase 3
pass_list += [
tvm.tir.transform.Simplify(),
tvm.tir.transform.RemoveNoOp(),
]

pass_list += [tvm.tir.transform.RewriteUnsafeSelect()]
pass_list += [tvm.tir.transform.HoistIfThenElse()]
pass_list += lower_phase3

# Instrument BoundCheckers
if instrument_bound_checkers:
pass_list += [tvm.tir.transform.InstrumentBoundCheckers()]

optimize = tvm.transform.Sequential(pass_list)
mod = optimize(mod)
return mod
if isinstance(inp, IRModule):
return ffi.lower_module(inp, simple_mode)
if isinstance(inp, PrimFunc):
return ffi.lower_primfunc(inp, name, simple_mode)
if isinstance(inp, schedule.Schedule):
return ffi.lower_schedule(inp, args, name, binds, simple_mode)
raise ValueError("Expected input to be an IRModule, PrimFunc or Schedule, but got, ", type(inp))


def _build_for_device(input_mod, target, target_host):
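With the pass pipeline moved into C++, the new lower is a thin dispatcher over three FFI entry points. A hedged end-to-end sketch of the three accepted input types (illustration only, not part of the diff):

import tvm
from tvm import te
from tvm.driver.build_module import schedule_to_module

n = te.var("n")
A = te.placeholder((n,), name="A")
B = te.compute((n,), lambda i: A[i] + 1.0, name="B")
s = te.create_schedule(B.op)

# 1. TE schedule -> ffi.lower_schedule (args are required on this path)
mod_from_sched = tvm.lower(s, [A, B], name="main")

# 2. Unlowered IRModule -> ffi.lower_module
unlowered = schedule_to_module(s, [A, B], name="main")
mod_from_mod = tvm.lower(unlowered)

# 3. PrimFunc -> ffi.lower_primfunc
mod_from_func = tvm.lower(unlowered["main"], name="main")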
2 changes: 1 addition & 1 deletion python/tvm/driver/tvmc/compiler.py
@@ -38,7 +38,7 @@

@register_parser
def add_compile_parser(subparsers):
""" Include parser for 'compile' subcommand """
"""Include parser for 'compile' subcommand"""

parser = subparsers.add_parser("compile", help="compile a model.")
parser.set_defaults(func=drive_compile)