Unify Python and C++ TIR lower API #8110
Changes from all commits
@@ -34,6 +34,7 @@
 #include <tvm/support/with.h>
 #include <tvm/target/target.h>
 #include <tvm/te/schedule_pass.h>
+#include <tvm/tir/function.h>

 #include <string>
 #include <unordered_map>
@@ -42,17 +43,68 @@
 #include <vector>

 namespace tvm {

+/*!
+ * \brief Lower an IRModule (optimize it with the pass list defined in CreatePassList).
+ * \param mod The IRModule to lower.
+ * \param simple_mode Disables the loop partition pass. Defaults to false.
+ * \return The result module.
+ */
+TVM_DLL IRModule LowerModule(IRModule mod, bool simple_mode = false);
+
+/*!
+ * \brief Lower a PrimFunc and name (convert to an IRModule and optimize it with the pass list
+ * defined in CreatePassList).
+ * \param func The PrimFunc to lower.
+ * \param name The name of the lowered function.
+ * \param simple_mode Disables the loop partition pass. Defaults to false.
+ * \return The result module.
+ */
+TVM_DLL IRModule LowerPrimFunc(tvm::tir::PrimFunc func, const std::string& name,
+                               bool simple_mode = false);
+
 /*!
- * \brief Build an IRModule given a schedule, args and binds
- * \param sch The schedule to lower.
+ * \brief Build an IRModule given a TE schedule, args and binds. This function also applies
+ * the lowering passes defined in CreatePassList.
+ * \param sch The TE schedule to lower.
  * \param args The arguments to the function.
  * \param name The name of the lowered function.
  * \param binds Buffer assignments.
+ * \param simple_mode Disables the loop partition pass. Defaults to false.
  * \return The result module.
  */
-TVM_DLL IRModule lower(te::Schedule sch, const Array<te::Tensor>& args, const std::string& name,
-                       const std::unordered_map<te::Tensor, tir::Buffer>& binds);
+TVM_DLL IRModule LowerSchedule(te::Schedule sch, const Array<te::Tensor>& args,
+                               const std::string& name,
+                               const std::unordered_map<te::Tensor, tir::Buffer>& binds,
+                               bool simple_mode = false);
+
+/*!
+ * \brief Build an IRModule given a TE schedule, args and binds. This function also applies
+ * the lowering passes defined in CreatePassList.
+ * \param sch The TE schedule to lower.
+ * \param args The arguments to the function (Array of Tensor, Buffer and Vars).
+ * \param name The name of the lowered function.
+ * \param binds Buffer assignments.
+ * \param simple_mode Disables the loop partition pass. Defaults to false.
+ * \return The result module.
+ */
+TVM_DLL IRModule LowerSchedule(te::Schedule sch, const Array<ObjectRef>& args,
+                               const std::string& name,
+                               const std::unordered_map<te::Tensor, tir::Buffer>& binds,
+                               bool simple_mode = false);
+
+/*!
+ * \brief Create an IRModule out of a TE Schedule. It does not apply lowering passes. If you want
+ * to apply lowering passes as well, use LowerSchedule.
+ * \param sch The schedule.
+ * \param args The arguments to the function.
+ * \param name The name of the lowered function.
+ * \param binds Buffer assignments.
+ * \return The result module.
+ */
+IRModule ScheduleToModule(te::Schedule sch, const Array<ObjectRef>& args, const std::string& name,
Contributor
This looks almost exactly the same as

Contributor
Initially this was a function called

Contributor
Can you be a little more explicit about this in the documentation?

Contributor
@tkonolige does this look good to you?
+                          const std::unordered_map<te::Tensor, tir::Buffer>& binds);
 /*!
  * \brief Build a device and host module for a specific target from an IRModule.
  * \param funcs The functions to be built.
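To make the new split concrete, here is a minimal Python sketch of the two entry points: schedule_to_module forms an IRModule without running lowering passes, while tvm.lower (backed by LowerSchedule) also applies the CreatePassList pipeline. The toy workload and the import path from this PR's build_module.py are illustrative assumptions, not part of the diff.

import tvm
from tvm import te
from tvm.driver.build_module import schedule_to_module

# Toy workload, for illustration only.
n = 1024
A = te.placeholder((n,), name="A")
B = te.compute((n,), lambda i: A[i] + 1.0, name="B")
s = te.create_schedule(B.op)

# ScheduleToModule path: builds the IRModule but applies no lowering passes.
mod_raw = schedule_to_module(s, [A, B], name="main")

# LowerSchedule path: the same module with the CreatePassList passes applied.
mod_lowered = tvm.lower(s, [A, B], name="main")

print(mod_raw)
print(mod_lowered)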
@@ -0,0 +1,20 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""FFI APIs for tvm.driver"""
+import tvm._ffi
+
+tvm._ffi._init_api("driver", __name__)
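For readers unfamiliar with the FFI plumbing: _init_api scans the global PackedFunc registry for names under the given prefix and attaches them to this module as attributes. A rough sketch of the equivalence; the registered name "driver.lower_module" is inferred from the ffi.lower_module call sites in this PR rather than stated in the diff.

import tvm

# ffi.lower_module(...) resolves to the C++ function registered under the
# "driver." prefix; fetching the packed function by hand is roughly equivalent.
lower_module = tvm.get_global_func("driver.lower_module")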
@@ -37,96 +37,58 @@
 from tvm.tir.buffer import Buffer
 from tvm.tir.expr import Var

+from . import _ffi_api as ffi
+

 def get_binds(args, compact=False, binds=None):
     """Internal function to get binds and arg_list given arguments.

     Parameters
     ----------
     args : list of Buffer or Tensor or Var
         The argument lists to the function.

     compact : bool
         If the statement has already bound to a compact buffer.

     binds : dict of :any:`Tensor` to :any:`Buffer`, optional
         Dictionary that maps the Tensor to Buffer which specified the data layout
         requirement of the function. By default, a new compact buffer is created
         for each tensor in the argument.

     Returns
     -------
     binds: dict
         The bind specification

     arg_list: list
         The list of symbolic buffers of arguments.
     """
-    binds = {} if binds is None else binds.copy()
-    arg_list = []
-    for x in args:
-        if isinstance(x, tensor.Tensor):
-            any_dim = any(isinstance(i, tvm.tir.Var) for i in x.shape)
-            buffer_type = "auto_broadcast" if any_dim and not compact else ""
-            if x not in binds:
-                buf = tvm.tir.decl_buffer(
-                    x.shape, dtype=x.dtype, name=x.name, buffer_type=buffer_type
-                )
-                binds[x] = buf
-                arg_list.append(buf)
-            else:
-                arg_list.append(binds[x])
-        elif isinstance(x, schedule.Buffer):
-            arg_list.append(x)
-        elif isinstance(x, tvm.tir.Var):
-            arg_list.append(x)
-        else:
-            raise ValueError("args must be Tensor, Buffer or Var")
+    binds, arg_list = ffi.get_binds(args, compact, binds)
     return binds, arg_list


-def form_irmodule(sch, args, name, binds):
+def schedule_to_module(
+    sch: schedule.Schedule,
+    args: Optional[List[Union[Buffer, tensor.Tensor, Var]]] = None,
+    name: str = "main",
+    binds: Optional[Mapping[tensor.Tensor, Buffer]] = None,
+) -> IRModule:
     """According to the given schedule, form a function.

     Parameters
     ----------
     sch : tvm.te.schedule.Schedule
         The given scheduler to form the raw body

     args : list of Buffer or Tensor or Var
         The argument lists to the function.

     name : str
-        The name of result function.
+        The name of result function, default name is "main"

     binds : dict of :any:`Tensor` to :any:`Buffer`, optional
         The binds information

     Returns
     -------
     The body formed according to the given schedule
     """
-    # normalize schedule first
-    pass_ctx = PassContext.current()
-    sch = sch.normalize()
-    bounds = schedule.InferBound(sch)
-    stmt = schedule.ScheduleOps(sch, bounds)
-
-    compact = schedule.VerifyCompactBuffer(stmt)
-    binds, arg_list = get_binds(args, compact, binds)
-
-    stmt = schedule.SchedulePostProcRewriteForTensorCore(stmt, sch, binds)
-    func = schedule.SchedulePostProcToPrimFunc(arg_list, stmt, binds)
-
-    func = func.with_attr("global_symbol", name)
-
-    if pass_ctx.config.get("tir.noalias", True):
-        func = func.with_attr("tir.noalias", True)
-    return tvm.IRModule({name: func})
+    return ffi.schedule_to_module(sch, args, name, binds)


 def lower(
-    inputs: Union[schedule.Schedule, PrimFunc, IRModule],
+    inp: Union[schedule.Schedule, PrimFunc, IRModule],
     args: Optional[List[Union[Buffer, tensor.Tensor, Var]]] = None,
     name: str = "main",
     binds: Optional[Mapping[tensor.Tensor, Buffer]] = None,

@@ -136,7 +98,7 @@ def lower(

     Parameters
     ----------
-    input : Union[schedule.Schedule, PrimFunc, IRModule]
+    inputs : Union[schedule.Schedule, PrimFunc, IRModule]
Contributor
Why is the name

Contributor
The word input is a Python built-in function, so you can't name variables

Contributor
Maybe use

Contributor
Hi @electriclilies, I guess even if
Please correct me if I am wrong. :)

Contributor
The problem is actually that the linter doesn't like it -- I guess it just flags everything called input for some reason.

Contributor
Thanks for the detail! :) It's a pylint warning so you can disable it using

Contributor
oh, thanks! Well I just changed it to

Contributor
Sounds good!
         The TE schedule or TensorIR PrimFunc/IRModule to be built

     args : Optional[List[Union[Buffer, tensor.Tensor, Var]]]
@@ -160,90 +122,13 @@ def lower(
     m : IRModule
         The result IRModule
     """
-    # config setup
-    pass_ctx = PassContext.current()
-    instrument_bound_checkers = bool(pass_ctx.config.get("tir.instrument_bound_checkers", False))
-    disable_vectorize = bool(pass_ctx.config.get("tir.disable_vectorize", False))
-    add_lower_pass = pass_ctx.config.get("tir.add_lower_pass", [])
-
-    lower_phase0 = [x[1] for x in add_lower_pass if x[0] == 0]
-    lower_phase1 = [x[1] for x in add_lower_pass if x[0] == 1]
-    lower_phase2 = [x[1] for x in add_lower_pass if x[0] == 2]
-    lower_phase3 = [x[1] for x in add_lower_pass if x[0] > 2]
-
-    # Phase 0
-    pass_list = lower_phase0
-    is_legacy_te_schedule: bool = False
-
-    if isinstance(inputs, schedule.Schedule):
-        if args is None:
-            raise ValueError("args must be given for lowering from TE schedule")
-        mod = form_irmodule(inputs, args, name, binds)
-        is_legacy_te_schedule = True
-    elif isinstance(inputs, PrimFunc):
-        func = inputs.with_attr("global_symbol", name)
-        if pass_ctx.config.get("tir.noalias", True):
-            func = func.with_attr("tir.noalias", True)
-        mod = tvm.IRModule({name: func})
-    elif isinstance(inputs, IRModule):
-        mod = inputs
-    else:
-        raise TypeError(
-            f"tvm.lower expected te.Schedule, PrimFunc or IRModule, but got {type(inputs)}"
-        )
-
-    # Phase 1
-    if is_legacy_te_schedule:
-        pass_list += [
-            tvm.tir.transform.InjectPrefetch(),
-            tvm.tir.transform.StorageFlatten(64, instrument_bound_checkers),
-        ]
-    else:
-        pass_list += [
-            tvm.tir.transform.LowerInitBlock(),
-            tvm.tir.transform.PlanAndUpdateBufferAllocationLocation(),
-            tvm.tir.transform.ConvertBlocksToOpaque(),
-            tvm.tir.transform.CompactBufferAllocation(),
-            tvm.tir.transform.FlattenBuffer(),
-        ]
-    pass_list += [
-        tvm.tir.transform.BF16Legalize(),
-        tvm.tir.transform.NarrowDataType(32),
-        tvm.tir.transform.Simplify(),
-    ]
-
-    pass_list += lower_phase1
-
-    # Phase 2
-    if not simple_mode:
-        pass_list += [(tvm.tir.transform.LoopPartition())]
-
-    pass_list += [
-        tvm.tir.transform.VectorizeLoop(not disable_vectorize),
-        tvm.tir.transform.InjectVirtualThread(),
-        tvm.tir.transform.InjectDoubleBuffer(),
-        tvm.tir.transform.StorageRewrite(),
-        tvm.tir.transform.UnrollLoop(),
-    ]
-    pass_list += lower_phase2
-
-    # Phase 3
-    pass_list += [
-        tvm.tir.transform.Simplify(),
-        tvm.tir.transform.RemoveNoOp(),
-    ]
-
-    pass_list += [tvm.tir.transform.RewriteUnsafeSelect()]
-    pass_list += [tvm.tir.transform.HoistIfThenElse()]
-    pass_list += lower_phase3
-
-    # Instrument BoundCheckers
-    if instrument_bound_checkers:
-        pass_list += [tvm.tir.transform.InstrumentBoundCheckers()]
-
-    optimize = tvm.transform.Sequential(pass_list)
-    mod = optimize(mod)
-    return mod
+    if isinstance(inp, IRModule):
+        return ffi.lower_module(inp, simple_mode)
+    if isinstance(inp, PrimFunc):
+        return ffi.lower_primfunc(inp, name, simple_mode)
+    if isinstance(inp, schedule.Schedule):
+        return ffi.lower_schedule(inp, args, name, binds, simple_mode)
+    raise ValueError("Expected input to be an IRModule, PrimFunc or Schedule, but got, ", type(inp))


 def _build_for_device(input_mod, target, target_host):
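The rewritten lower is now a thin dispatcher over the three C++ entry points. Below is a hedged sketch exercising all three paths; the workload is illustrative, and it assumes te.create_prim_func is available in your build to produce a TensorIR PrimFunc.

import tvm
from tvm import te

# Toy computation reused for all three dispatch paths.
A = te.placeholder((128,), name="A")
B = te.compute((128,), lambda i: A[i] * 2.0, name="B")

# te.Schedule input -> ffi.lower_schedule
s = te.create_schedule(B.op)
mod_from_sched = tvm.lower(s, [A, B], name="main")

# PrimFunc input -> ffi.lower_primfunc
func = te.create_prim_func([A, B])
mod_from_func = tvm.lower(func, name="main")

# IRModule input -> ffi.lower_module
mod = tvm.IRModule({"main": func.with_attr("global_symbol", "main")})
mod_from_mod = tvm.lower(mod)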
Why is there a special flag to disable the loop partition pass? Shouldn't the existing PassContext infrastructure handle this?

I'm not completely sure. This logic was already in the code base, so I just duplicated it and made it more explicit. (For context, when I started this refactor, Jared said he wanted to get it in quickly, so I should just try to naively duplicate the existing logic, but he hasn't reviewed it yet, so I'm not sure what the timeline is now.)
I can try to remove it, but I'm not sure what the best way to go about this is, since there are a few tests that call lower directly.

It is probably fine for now.
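For context, the PassContext route suggested above would presumably look like the sketch below; it assumes tir.LoopPartition honors the disabled_pass list, which is exactly the part this thread leaves unresolved. simple_mode bakes the same choice into the lower signature instead.

import tvm
from tvm import te

A = te.placeholder((64,), name="A")
B = te.compute((64,), lambda i: A[i] + 1.0, name="B")
s = te.create_schedule(B.op)

# The explicit flag added in this PR: skip tir.LoopPartition while lowering.
mod_simple = tvm.lower(s, [A, B], simple_mode=True)

# Hypothetical PassContext-based alternative; assumes the pass respects
# the disabled_pass list.
with tvm.transform.PassContext(disabled_pass=["tir.LoopPartition"]):
    mod_ctx = tvm.lower(s, [A, B])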