Skip to content
Merged
Show file tree
Hide file tree
Changes from 29 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
e97d9f1
[Draft][TIR] Remove PrimFuncNode::preflattened_buffer_map
Lunderberg Apr 7, 2022
33b5b63
Fix lint errors
Lunderberg Apr 11, 2022
d855300
Replacing T.preflattened_buffer with T.buffer_decl in unit tests
Lunderberg Apr 11, 2022
7bdbf5d
Remove preflattened_buffer from TVMScript stubs
Lunderberg Apr 15, 2022
06b0260
Removed additional preflattened usages after rebase
Lunderberg Apr 19, 2022
ae16274
Updated tir::PrimFunc usage in cmsisnn contrib
Lunderberg Apr 19, 2022
5f47164
Removing more usage of preflattened from python files
Lunderberg Apr 19, 2022
83028e1
Removing duplicate buffer names
Lunderberg Apr 21, 2022
0773279
Corrected linting errors
Lunderberg Apr 21, 2022
877cc24
Updated BufferAllocationLocator to ignore aliases of arg buffers
Lunderberg Apr 25, 2022
2fde1a9
Replaced more preflatten occurrences
Lunderberg May 18, 2022
074d688
Merge branch 'main' into remove_preflattened_buffer_map
Lunderberg May 18, 2022
c95b186
Removed preflatten usage from merge
Lunderberg May 18, 2022
df85734
T.handle -> T.Buffer in PrimFunc args for AOT test
Lunderberg May 18, 2022
a912f03
Merge branch 'main' into remove_preflattened_buffer_map
Lunderberg May 27, 2022
31ddb16
Merge branch 'main' into remove_preflattened_buffer_map
Lunderberg Sep 27, 2022
65f3ebd
More removal of preflattened instances
Lunderberg Sep 28, 2022
fdbc64d
lint fixes
Lunderberg Sep 28, 2022
cfa7687
Merge branch 'main' into remove_preflattened_buffer_map
Lunderberg Nov 9, 2022
bc99a55
Update following merge
Lunderberg Nov 9, 2022
ff9fb58
Directly write vector function instead of relying on tvm.lower
Lunderberg Nov 9, 2022
cb15f96
Updates to ethos-u constant encoding to avoid breakage
Lunderberg Nov 9, 2022
9d7853f
A few more ethos-u updates
Lunderberg Nov 9, 2022
c2b03bd
Merge branch 'main' into remove_preflattened_buffer_map
Lunderberg Nov 10, 2022
bd4b2dc
Updates following latest merge
Lunderberg Nov 10, 2022
d201e62
Fixed updates in TVMScript for test_replace_conv2d
Lunderberg Nov 10, 2022
c37ec25
Fixing breakage in test_hoist_allocates.py
Lunderberg Nov 10, 2022
299b7b3
Resolve breakage in test_merge_constants.py
Lunderberg Nov 10, 2022
de1d5fe
Remove some debug code that broke PassContext
Lunderberg Nov 10, 2022
d5fe2ed
Merge branch 'main' into remove_preflattened_buffer_map
Lunderberg Nov 15, 2022
2292229
Updated TVMScript representation of Ramp
Lunderberg Nov 15, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions include/tvm/script/ir_builder/tir/frame.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,6 @@ class PrimFuncFrameNode : public TIRFrameNode {
Optional<Type> ret_type;
/*! \brief Maps some parameters to specific Buffer data structures. */
Map<tvm::tir::Var, tvm::tir::Buffer> buffer_map;
/*! \brief The buffer map prior to flattening. */
Map<tvm::tir::Var, tvm::tir::Buffer> preflattened_buffer_map;
/*! \brief Additional attributes storing the meta-data */
Optional<Map<String, ObjectRef>> attrs;
/*! \brief The variable map bound to thread env. */
Expand All @@ -90,7 +88,6 @@ class PrimFuncFrameNode : public TIRFrameNode {
v->Visit("args", &args);
v->Visit("ret_type", &ret_type);
v->Visit("buffer_map", &buffer_map);
v->Visit("preflattened_buffer_map", &preflattened_buffer_map);
v->Visit("attrs", &attrs);
v->Visit("env_threads", &env_threads);
v->Visit("root_alloc_buffers", &root_alloc_buffers);
Expand Down
20 changes: 0 additions & 20 deletions include/tvm/script/ir_builder/tir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,26 +114,6 @@ Buffer MatchBuffer(ObjectRef param, Array<PrimExpr> shape, DataType dtype = Data
int align = -1, int offset_factor = 0, String buffer_type = "default",
Array<IntImm> axis_separators = {});

/*!
* \brief The pre-flattened buffer statement.
* \param postflattened_buffer The original buffer to be flattened.
* \param shape The type of the buffer prior to flattening.
* \param dtype The data type in the content of the buffer.
* \param data The pointer to the head of the data.
* \param strides The strides of each dimension.
* \param elem_offset The offset in terms of number of dtype elements (including lanes).
* \param storage_scope The optional storage scope of buffer data pointer.
* \param align The alignment requirement of data pointer in bytes.
* \param offset_factor The factor of elem_offset field.
* \param buffer_type The buffer type.
* \param axis_separators The separators between input axes when generating flattened output axes.
*/
void PreflattenedBuffer(Buffer postflattened_buffer, Array<PrimExpr> shape,
DataType dtype = DataType::Float(32), Optional<Var> data = NullOpt,
Array<PrimExpr> strides = {}, PrimExpr elem_offset = PrimExpr(),
String storage_scope = "global", int align = -1, int offset_factor = 0,
String buffer_type = "default", Array<IntImm> axis_separators = {});

/*!
* \brief The block declaration statement.
* \param name The name of the block.
Expand Down
43 changes: 11 additions & 32 deletions include/tvm/tir/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,33 +88,22 @@ class PrimFuncNode : public BaseFuncNode {
* While we could have express parameter unpacking and constraint using
* normal statements, making buffer_map as first class citizen of PrimFunc
* will make program analysis much easier.
*/
Map<tir::Var, Buffer> buffer_map;

/*! \brief The buffer map prior to flattening.
*
* This contains the buffers as they exists prior to flattening, and
* is used for validating an input tensor passed into the packed
* API. Any buffer that is present in `buffer_map` but not present
* in `preflattened_buffer_map` is assumed to be the same before
* and after flattening (e.g. a 1-d tensor that is backed by 1-d
* flat memory).
*
* TODO(Lunderberg): Remove preflattened_buffer_map, and instead
* declare each flattened buffer as aliasing the original tensor
* shape. This should include improving the StmtExprMutator to
* provide easier interactions with Buffer objects, so that the
* bookkeeping of relationships between buffers doesn't need to be
* repeated across several transforms.
* Prior to buffer flattening, which is performed either in
* StorageFlatten for TE-based schedules or in FlattenBuffer for
* TIR-based schedules, these buffer objects are used directly in
* the body of the function. After buffer flattening, these buffer
* objects remain unflattened for use in argument validation, but
* all usage in the body of the function is done through a
* flattened alias of the buffer.
*/
Map<tir::Var, Buffer> preflattened_buffer_map;
Map<tir::Var, Buffer> buffer_map;

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("params", &params);
v->Visit("body", &body);
v->Visit("ret_type", &ret_type);
v->Visit("buffer_map", &buffer_map);
v->Visit("preflattened_buffer_map", &preflattened_buffer_map);
v->Visit("attrs", &attrs);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
Expand All @@ -123,15 +112,13 @@ class PrimFuncNode : public BaseFuncNode {
bool SEqualReduce(const PrimFuncNode* other, SEqualReducer equal) const {
// visit params and buffer_map first as they contains defs.
return equal.DefEqual(params, other->params) && equal(buffer_map, other->buffer_map) &&
equal(preflattened_buffer_map, other->preflattened_buffer_map) &&
equal(ret_type, other->ret_type) && equal(body, other->body) &&
equal(attrs, other->attrs);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce.DefHash(params);
hash_reduce(buffer_map);
hash_reduce(preflattened_buffer_map);
hash_reduce(ret_type);
hash_reduce(body);
hash_reduce(attrs);
Expand Down Expand Up @@ -169,21 +156,13 @@ class PrimFunc : public BaseFunc {
* PrimFunc. (e.g. a buffer of shape ``[1024]`` originally
* generated as a tensor of shape ``[32, 32]``)
*
* \param preflattened_buffer_map The buffer map for
* parameter buffer unpacking. This contains buffer
* objects as they are expected to be passed in by the
* callee. (e.g. a buffer of shape ``[32, 32]`` originally
* generated as a tensor of shape ``[32, 32]``)
*
* \param attrs Additional function attributes.
*
* \param span The location of this object in the source code.
*/
TVM_DLL PrimFunc(
Array<tir::Var> params, Stmt body, Type ret_type = VoidType(),
Map<tir::Var, Buffer> buffer_map = Map<tir::Var, Buffer>(),
Optional<Map<tir::Var, Buffer>> preflattened_buffer_map = Optional<Map<tir::Var, Buffer>>(),
DictAttrs attrs = NullValue<DictAttrs>(), Span span = Span());
TVM_DLL PrimFunc(Array<tir::Var> params, Stmt body, Type ret_type = VoidType(),
Map<tir::Var, Buffer> buffer_map = Map<tir::Var, Buffer>(),
DictAttrs attrs = NullValue<DictAttrs>(), Span span = Span());

TVM_DEFINE_OBJECT_REF_METHODS(PrimFunc, BaseFunc, PrimFuncNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(PrimFuncNode);
Expand Down
77 changes: 53 additions & 24 deletions python/tvm/relay/backend/contrib/ethosu/tir/passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,6 @@ def _ftransform(f, mod, ctx):
new_body,
f.ret_type,
new_buffer_map,
f.preflattened_buffer_map,
f.attrs,
f.span,
)
Expand Down Expand Up @@ -327,7 +326,7 @@ def EncodeConstants(const_dict):
"""
new_const_dict = {}

def collect_encoding_definitions(stmt, old_buffer_to_const):
def collect_encoding_definitions(stmt, old_buffer_var_to_const):
# Map from copy destination to copy source.
copy_map = {}
# List of buffer copies that occurred
Expand Down Expand Up @@ -376,7 +375,7 @@ def _declare_constant_buffer(old_buffer, encoded_constants, split_idx):
def _encode_weights_or_bias(buffer1, buffer2, stmt, encode_func):
"""Encode the weights or align the bias either for one or two cores,
depending on the variant."""
constant = old_buffer_to_const[buffer1]
constant = old_buffer_var_to_const[buffer1.data]

# If we have just one core, encode the whole constant
if buffer2 is None:
Expand Down Expand Up @@ -471,7 +470,12 @@ def _visit(stmt):
}

def transform_stmt(
stmt, buf_remap, var_remap, pointer_to_buffer, new_buffer_to_const, new_buffer_to_split_idx
stmt,
buf_remap,
var_remap,
pointer_to_buffer,
new_buffer_var_to_const,
new_buffer_to_split_idx,
):
def _visit_rewrite(stmt):
if isinstance(stmt, tvm.tir.Call):
Expand All @@ -485,7 +489,7 @@ def _visit_rewrite(stmt):
# encoded buffer, the current should be a length.
if (
isinstance(prev_arg, tvm.tir.BufferLoad)
and prev_arg.buffer in new_buffer_to_const
and prev_arg.buffer.data in new_buffer_var_to_const
):
buffer_size = np.prod(list(prev_arg.buffer.shape))
arg = buffer_size
Expand Down Expand Up @@ -554,28 +558,56 @@ def _visit_rewrite(stmt):
["tir.Call", "tir.Allocate", "tir.BufferLoad", "tir.AttrStmt"],
)

def _collect_parameter_buffer_aliases(prim_func):
buffer_vars = {}
for param in prim_func.params:
if param in prim_func.buffer_map:
buf = prim_func.buffer_map[param]
buffer_vars[buf.data] = {buf}

def visit(node):
if isinstance(node, (tvm.tir.BufferStore, tvm.tir.BufferLoad, tvm.tir.DeclBuffer)):
buf = node.buffer
if buf.data in buffer_vars:
buffer_vars[buf.data].add(buf)

tvm.tir.stmt_functor.post_order_visit(prim_func.body, visit)
return buffer_vars

def _ftransform(f, mod, ctx):
param_buffer_var_usage = _collect_parameter_buffer_aliases(f)

# Step 0: Unpack the constant dictionary in terms of the
# functions buffers.
old_buffer_to_const = {}
old_buffer_var_to_const = {}
for i, param in enumerate(f.params):
if i in const_dict:
old_buffer_to_const[f.buffer_map[param]] = const_dict[i]
old_buffer_var_to_const[f.buffer_map[param].data] = const_dict[i]

# Step 1: Collect information on the buffers that will be
# replaced by encodings.
buffer_information = collect_encoding_definitions(f.body, old_buffer_to_const)
buffer_information = collect_encoding_definitions(f.body, old_buffer_var_to_const)

# Step 2: Generate variable/buffer remaps, based on the
# collected information.
buf_remap = {}
new_buffer_to_const = {}
new_buffer_var_to_const = {}
new_buffer_to_split_idx = {}

def define_remap(old_buf, new_buf):
try:
old_buffers = param_buffer_var_usage[old_buf.data]
except KeyError:
old_buffers = [old_buf]

for old_buffer in old_buffers:
buf_remap[old_buffer] = new_buf

# Any encoded buffers must be replaced
for info in buffer_information["constant_buffer_replacements"]:
buf_remap[info["old_buffer"]] = info["new_buffer"]
new_buffer_to_const[info["new_buffer"]] = info["encoded_constants"]
define_remap(info["old_buffer"], info["new_buffer"])

new_buffer_var_to_const[info["new_buffer"].data] = info["encoded_constants"]

if info["split_idx"]:
new_buffer_to_split_idx[info["new_buffer"]] = info["split_idx"]
Expand All @@ -596,9 +628,11 @@ def _ftransform(f, mod, ctx):
name=copy_dest.name,
scope=copy_dest.scope(),
)
buf_remap[copy_dest] = new_dest
if copy_source in new_buffer_to_const:
new_buffer_to_const[new_dest] = new_buffer_to_const[copy_source]
define_remap(copy_dest, new_dest)
if copy_source.data in new_buffer_var_to_const:
new_buffer_var_to_const[new_dest.data] = new_buffer_var_to_const[
copy_source.data
]

if copy_source in new_buffer_to_split_idx:
new_buffer_to_split_idx[new_dest] = new_buffer_to_split_idx[copy_source]
Expand All @@ -615,7 +649,7 @@ def _ftransform(f, mod, ctx):
buf_remap,
var_remap,
pointer_to_buffer,
new_buffer_to_const,
new_buffer_var_to_const,
new_buffer_to_split_idx,
)

Expand All @@ -626,10 +660,10 @@ def _ftransform(f, mod, ctx):
if buffer in buf_remap:
buffer = buf_remap[buffer]

if buffer in new_buffer_to_const:
new_const_dict[i] = new_buffer_to_const[buffer].flatten()
elif buffer in old_buffer_to_const:
new_const_dict[i] = old_buffer_to_const[buffer].flatten()
if buffer.data in new_buffer_var_to_const:
new_const_dict[i] = new_buffer_var_to_const[buffer.data].flatten()
elif buffer.data in old_buffer_var_to_const:
new_const_dict[i] = old_buffer_var_to_const[buffer.data].flatten()

new_buffer_map[param] = buffer

Expand All @@ -638,7 +672,6 @@ def _ftransform(f, mod, ctx):
new_body,
f.ret_type,
new_buffer_map,
f.preflattened_buffer_map,
f.attrs,
f.span,
)
Expand Down Expand Up @@ -873,7 +906,6 @@ def CreatePrimFuncWithoutConstants(const_dict):
def _ftransform(f, mod, ctx):
new_params = list()
new_buffer_map = dict()
new_preflattened_buffer_map = dict()
for param_idx in const_dict.keys():
# We are using buffer_var to key the constants as
# PrimFunc params of constants will be removed.
Expand All @@ -882,14 +914,11 @@ def _ftransform(f, mod, ctx):
if i not in const_dict.keys():
new_params.append(param)
new_buffer_map[param] = f.buffer_map[param]
if param in f.preflattened_buffer_map:
new_preflattened_buffer_map[param] = f.preflattened_buffer_map[param]
return tvm.tir.PrimFunc(
new_params,
f.body,
f.ret_type,
new_buffer_map,
new_preflattened_buffer_map,
f.attrs,
f.span,
)
Expand Down
3 changes: 0 additions & 3 deletions python/tvm/script/context_maintainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,6 @@ class ContextMaintainer:
"""List[Var]: The function parameters"""
func_buffer_map: Mapping[Var, Buffer] = {}
"""Mapping[Var, Buffer]: The function buffer map"""
func_preflattened_buffer_map: Mapping[Var, Buffer] = {}
"""Mapping[Var, Buffer]: The function buffer map, prior to any flattening."""
func_dict_attr: Mapping[str, Object] = {}
"""Mapping[str, Object]: The function attrs"""
func_var_env_dict: Mapping[Var, str] = {}
Expand Down Expand Up @@ -160,7 +158,6 @@ def __init__(
# function context
self.func_params = []
self.func_buffer_map = {}
self.func_preflattened_buffer_map = {}
self.func_dict_attr = {}
self.func_var_env_dict = {}
# parser and analyzer
Expand Down
Loading