Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
191 commits
Select commit Hold shift + click to select a range
63d2352
[TIR] Added BufferLoadNode::LegalizeDtype
Lunderberg Nov 8, 2021
f698ce7
Replacing Store/Load in Stmt/Expr Visitor/Mutator
Lunderberg Nov 10, 2021
4331f36
Removing Store/Load from optimization passes
Lunderberg Nov 8, 2021
ebdc8c7
Removing Store/Load from examples
Lunderberg Nov 15, 2021
bba004a
Replacing Store/Load in StorageFlatten
Lunderberg Nov 8, 2021
7c33db5
Replacing Store/Load in utility passes.
Lunderberg Nov 8, 2021
da05133
Replacing Store/Load in analysis functions
Lunderberg Nov 8, 2021
dba2031
Replacing Store/Load in lowering/legalization passes.
Lunderberg Nov 15, 2021
cbc3a6b
Replacing Load/Store in codegens.
Lunderberg Nov 10, 2021
7399589
[UnitTest] Add unit tests to test physical layout remapping.
Lunderberg Oct 12, 2021
24259f0
Updated tvm::address_of() to hold BufferLoad instead of Load.
Lunderberg Nov 10, 2021
8e9ec52
[TIR] Added IndexMap class.
Lunderberg Oct 8, 2021
0a26801
Updated Buffer::vstore/vload to return BufferLoad/BufferStore objects.
Lunderberg Nov 12, 2021
78d88b5
[TE] Added Stage::transform_layout to the C++ TE implementation.
Lunderberg Oct 8, 2021
4fcb73b
Replace Store/Load with BufferStore/BufferLoad in ir_builder
Lunderberg Dec 11, 2021
97d6dc9
[TE] Added Stage.transform_layout to the Python TE interface.
Lunderberg Oct 11, 2021
a523574
Added pre_flattened_shape/pre_flattened_stride fields to Buffer.
Lunderberg Nov 16, 2021
3211d61
[UnitTest] Test N-d indices exposed to low-level codegen
Lunderberg Oct 22, 2021
74d75dc
[TIR] Added PrimFunc attribute "layout_transform_map", filled from TE.
Lunderberg Oct 11, 2021
ace7525
Added pre_flattened_type.
Lunderberg Jan 6, 2022
89081de
[UnitTest] Added tests for loop iteration order.
Lunderberg Dec 11, 2021
699bb17
[TIR] Added BufferNode::axis_separators
Lunderberg Dec 11, 2021
c3ff6f6
[TIR] Added ApplyLayoutTransforms as part of StorageFlatten.
Lunderberg Oct 12, 2021
acad83f
Update usage of ir_builder where necessary.
Lunderberg Dec 13, 2021
a62f449
[TE] Implement te::Transform
Lunderberg Dec 6, 2021
25ff74c
[TE] Added Stage::set_axis_separators.
Lunderberg Oct 13, 2021
3735b6f
[TIR] Expose tir.transform.ApplyLayoutTransforms for testing
Lunderberg Oct 12, 2021
1994b0f
Breakpoint, removed Store/Load nodes from use.
Lunderberg Nov 17, 2021
ea4c10a
[TE] Rewrite loop iteration order
Lunderberg Dec 6, 2021
ce8e29a
[TE] Fill BufferNode::axis_separators from StageNode
Lunderberg Oct 22, 2021
0a2396b
Breakpoint, layout_transform implemented.
Lunderberg Nov 18, 2021
fb4d71a
[TE] Return transformed iteration variables
Lunderberg Dec 10, 2021
289f192
Breakpoint, axis separators defined.
Lunderberg Nov 18, 2021
3396c05
Breakpoint, expose the transformed axes for use in TE scheduling.
Lunderberg Dec 10, 2021
e34221d
Moved Buffer's pre-flatten information to PrimFunc.
Lunderberg Jan 11, 2022
918ea2d
Updated ethos-u C++ unit tests to remove use of Load/Store.
Lunderberg Jan 24, 2022
0ffe060
Bugfix, layout transformation.
Lunderberg Jan 27, 2022
6e0ca38
In test directory, replacing all instances of T.load.
Lunderberg Jan 12, 2022
417ee2b
Return buffer object from tvm.tir.script.scope_handler.Allocate
Lunderberg Jan 12, 2022
e80914b
Added .astype to tvm.script.tir.node.BufferSlice
Lunderberg Jan 12, 2022
d87a561
Replacing all T.store TIR calls.
Lunderberg Jan 14, 2022
09859a6
Added LOG(FATAL) in constructor of Store/Load nodes.
Lunderberg Jan 20, 2022
8b5aab4
Updated tvmscript parser to report error for Store/Load nodes.
Lunderberg Jan 20, 2022
bc1b957
[TVMScript] Added T.preflattened_buffer stmt
Lunderberg Jan 25, 2022
6c111fc
[TVMScript] Updated TVMscript for BufferLoad/BufferStore
Lunderberg Jan 25, 2022
fb1f9fb
Updated test_tvmscript_roundtrip.py for BufferLoad/BufferStore.
Lunderberg Jan 19, 2022
d7b8f06
Updated TIR reference in USMP pool allocation unit tests.
Lunderberg Feb 1, 2022
49eeeca
fixup! Return buffer object from tvm.tir.script.scope_handler.Allocate
Lunderberg Feb 2, 2022
ff66339
fixup! Return buffer object from tvm.tir.script.scope_handler.Allocate
Lunderberg Feb 2, 2022
5bebd09
fixup! Replacing all T.store TIR calls.
Lunderberg Feb 2, 2022
2856fd9
fixup! Replacing all T.store TIR calls.
Lunderberg Feb 2, 2022
4d19cc7
fixup! Return buffer object from tvm.tir.script.scope_handler.Allocate
Lunderberg Feb 2, 2022
00292c3
fixup! In test directory, replacing all instances of T.load.
Lunderberg Feb 2, 2022
029496e
tir.ComputeInline, correct variable count.
Lunderberg Feb 2, 2022
3676f80
fixup! Replacing all T.store TIR calls.
Lunderberg Feb 2, 2022
cafadcc
fixup! Updated Buffer::vstore/vload to return BufferLoad/BufferStore …
Lunderberg Feb 2, 2022
04995f6
fixup! In test directory, replacing all instances of T.load.
Lunderberg Feb 2, 2022
9a22d14
fixup! In test directory, replacing all instances of T.load.
Lunderberg Feb 2, 2022
e1a16b7
fixup! Replacing all T.store TIR calls.
Lunderberg Feb 2, 2022
c69d569
Expose Buffer index flattening function to Python.
Lunderberg Feb 2, 2022
b5644f2
Updated test_tir_buffer.py offset tests.
Lunderberg Feb 2, 2022
777b379
fixup! Replacing all T.store TIR calls.
Lunderberg Feb 2, 2022
fc13666
fixup! Replacing all T.store TIR calls.
Lunderberg Feb 2, 2022
391a28b
fixup! Updated Buffer::vstore/vload to return BufferLoad/BufferStore …
Lunderberg Feb 2, 2022
6c1f861
fixup! Replacing Store/Load in lowering/legalization passes.
Lunderberg Feb 3, 2022
7907083
fixup! Replacing all T.store TIR calls.
Lunderberg Feb 3, 2022
293f54a
fixup! Updated ethos-u C++ unit tests to remove use of Load/Store.
Lunderberg Feb 3, 2022
e38bec0
fixup! Replacing Store/Load in lowering/legalization passes.
Lunderberg Feb 4, 2022
ed21da0
fixup! Updated ethos-u C++ unit tests to remove use of Load/Store.
Lunderberg Feb 4, 2022
1f593d4
fixup! Added .astype to tvm.script.tir.node.BufferSlice
Lunderberg Feb 4, 2022
cf5555a
fixup! In test directory, replacing all instances of T.load.
Lunderberg Feb 4, 2022
2119f0b
fixup! Replacing all T.store TIR calls.
Lunderberg Feb 4, 2022
4877be3
fixup! Replacing all T.store TIR calls.
Lunderberg Feb 4, 2022
a1e8ed4
fixup! In test directory, replacing all instances of T.load.
Lunderberg Feb 4, 2022
2f46152
fixup! Replacing all T.store TIR calls.
Lunderberg Feb 4, 2022
e6579b2
fixup! Replacing Store/Load in lowering/legalization passes.
Lunderberg Feb 4, 2022
5601423
[UnitTests] Added T.preflattened_buffer in expected result
Lunderberg Feb 4, 2022
40eaef4
fixup! In test directory, replacing all instances of T.load.
Lunderberg Feb 4, 2022
5cad482
[UnitTests] Bound checker update, compare against N-d buffer bounds.
Lunderberg Feb 4, 2022
4775f89
Fixup, bound checker vectorize test.
Lunderberg Feb 4, 2022
f990415
fixup! Return buffer object from tvm.tir.script.scope_handler.Allocate
Lunderberg Feb 4, 2022
b26a2cd
[UnitTest] Fixed breakage in InjectRollingBuffer test.
Lunderberg Feb 4, 2022
f400062
fixup! Return buffer object from tvm.tir.script.scope_handler.Allocate
Lunderberg Feb 4, 2022
01d699f
[UnitTest] Fixed breakage in flatten buffer unit tests.
Lunderberg Feb 4, 2022
2f93970
fixup! Return buffer object from tvm.tir.script.scope_handler.Allocate
Lunderberg Feb 4, 2022
17b963c
[UnitTests] Fixed breakage in test_tir_buffer.py
Lunderberg Feb 4, 2022
28f6339
fixup! Replacing Load/Store in codegens.
Lunderberg Feb 4, 2022
4d6496f
[UnitTest] ComputeInline, opaque access test updates
Lunderberg Feb 4, 2022
6f879d0
[UnitTest] Fixup, allow unit test to use `ib.pointer()[0]`.
Lunderberg Feb 7, 2022
4de03d4
fixup! Replacing Load/Store in codegens.
Lunderberg Feb 7, 2022
418d1aa
fixup! Replacing Store/Load in lowering/legalization passes.
Lunderberg Feb 7, 2022
0e614af
fixup! Replacing all T.store TIR calls.
Lunderberg Feb 8, 2022
3b3e7fb
Fixed failing codegen c host unit tests.
Lunderberg Feb 8, 2022
7c6ded0
Fixup, StorageFlatten when applied to post-StorageRewrite functions.
Lunderberg Feb 8, 2022
068b179
fixup, StorageFlatten
Lunderberg Feb 8, 2022
b5a1428
Bugfix, correctly represent void* in LLVM IR.
Lunderberg Feb 8, 2022
0b5b840
Update, replace tir.Load with tir.BufferLoad
Lunderberg Feb 8, 2022
249285a
Added TVMScript error check for matching buffer/index dimensionality
Lunderberg Feb 8, 2022
7b98cae
Bugfix, correct return type when lowering custom datatype.
Lunderberg Feb 8, 2022
75819ba
Bugfix, removed unused primfunc from test_tvmscript_complete.py
Lunderberg Feb 8, 2022
a096989
Updated test_meta_schedule_postproc_verify_gpu_code.py TIR
Lunderberg Feb 9, 2022
cb46d7e
Allowed ramp nodes with buffer use analysis.
Lunderberg Feb 9, 2022
5ae79fd
Updated tests in test_meta_schedule_postproc_verify_gpu_code.py
Lunderberg Feb 9, 2022
c0c9329
Updated TIR examples to be compatible with buffer dimension check.
Lunderberg Feb 9, 2022
3b20b42
Corrected section header in docstring.
Lunderberg Feb 9, 2022
63941f5
Corrected indices size check in CogeGenC.
Lunderberg Feb 9, 2022
216fa9a
Fixed breakage in LowerThreadAllreduce.
Lunderberg Feb 10, 2022
0c4194b
[UnitTests] Replaced Store/Load in CUDA codegen tests.
Lunderberg Feb 10, 2022
5329a05
Resolved breakage in C-based codegen for vectorized store/load.
Lunderberg Feb 10, 2022
53c0362
Bugfix, incorrect LCA for buffer access in root scope.
Lunderberg Feb 11, 2022
c636e9b
Added docstrings for TransformNode member variables.
Lunderberg Feb 11, 2022
e40414f
Added TODO for future removal of preflattened_buffer_map.
Lunderberg Feb 11, 2022
27552d6
Fixup, transform layout + cache write tests.
Lunderberg Feb 14, 2022
b85b4ee
Bugfix, correct element type for scalarized access.
Lunderberg Feb 14, 2022
a8b5fa3
Bugfix, cuda buffer indexing when declared as different type.
Lunderberg Feb 15, 2022
e2342dc
Cuda codegen, update reference.
Lunderberg Feb 15, 2022
70d9d3c
Bugfix, lower allreduce
Lunderberg Feb 15, 2022
2e09604
Removed obsolete comment.
Lunderberg Feb 15, 2022
2029ced
Changed PrimFunc constructor preflattened_buffer_map to Optional
Lunderberg Feb 15, 2022
c8f9015
Removed flatten_buffer argument from T.match_buffer.
Lunderberg Feb 15, 2022
8f97159
Correct call to VarUseDefAnalysis::VisitBuffer
Lunderberg Feb 15, 2022
e3e3d89
Reverted unintentional testing change, lanes=2.
Lunderberg Feb 15, 2022
077e2ba
Merged from main into dev branch to resolve conflicts.
Lunderberg Feb 16, 2022
d8b88a9
Updated lower_cross_thread_reduction to use buffer in allreduce
Lunderberg Feb 16, 2022
fa941c9
Updated transform_layout test to disable CSE
Lunderberg Feb 16, 2022
fb14c5e
Updated CSE unit tests to use BufferStore
Lunderberg Feb 16, 2022
120bb5b
Replaced Store/Load for vta.transform and unit tests.
Lunderberg Feb 16, 2022
e4c169d
Updated unit tests for lower_cross_thread_reduction.
Lunderberg Feb 17, 2022
6425882
Updated arange to use scalar tensors.
Lunderberg Feb 17, 2022
4d02048
Fix breakage in ethosu constant encoding.
Lunderberg Feb 18, 2022
8bf6573
Fix breakage in ethosu call argument checks.
Lunderberg Feb 18, 2022
77841ae
Resolve breakage from mismatched shape/index dimensions
Lunderberg Feb 18, 2022
c20709c
Split out encoded parameters from preflattened buffer map.
Lunderberg Feb 18, 2022
521556e
Updated buffer shape/index dimensions to match in more ethosu tests
Lunderberg Feb 18, 2022
7a2eb8e
Fixed lint error
Lunderberg Feb 18, 2022
b08245f
Removed debug code
Lunderberg Feb 18, 2022
f3d17b2
Moved arith::Analyzer local variable to class member
Lunderberg Feb 18, 2022
e8aa9d6
Fixed SSA conversion of allocations.
Lunderberg Feb 22, 2022
9fa1d07
Ethos-u index/buffer dimension updates.
Lunderberg Feb 22, 2022
b8710ad
Merge branch 'main' into physical_layout
Lunderberg Feb 22, 2022
3edb07d
Updated ethosu passes to handle buffer load/store.
Lunderberg Feb 22, 2022
ea0b4f9
Resolved bug in tvmscript printing of duplicate buffers.
Lunderberg Feb 23, 2022
3f52fa3
Fix breakage in ethos-u test_assign_addresses, encode constants
Lunderberg Feb 23, 2022
9d2564c
Merge branch 'main' into physical_layout
Lunderberg Feb 24, 2022
b476517
Apply same changes to T.allocate_const as to T.allocate
Lunderberg Feb 24, 2022
f29d417
Fix lint errors.
Lunderberg Feb 24, 2022
bf65156
Further updates for ethos-u tests.
Lunderberg Feb 24, 2022
2c60f51
Updated ethos.u buffer sizes in test.
Lunderberg Feb 24, 2022
a79b0ac
Updated tir.BindParams to use BufferLoad instead of Load.
Lunderberg Feb 24, 2022
62c3f90
Updated topi.cuda.scan implementation to follow buffer dimensions.
Lunderberg Feb 25, 2022
07dc8ab
Resolved breakage when flattening AllocateConst nodes.
Lunderberg Feb 25, 2022
bc1e5ae
Merge branch 'main' into physical_layout
Lunderberg Feb 25, 2022
cc1f3ae
Resolved breakages from latest merge with main.
Lunderberg Feb 25, 2022
09d33bb
Corrected error in merge.
Lunderberg Feb 25, 2022
24297e3
Use empty indices for rank-0 tensor.
Lunderberg Feb 25, 2022
14676bb
Added ir_builder workaround for 1-d indexing.
Lunderberg Feb 26, 2022
03f9164
Consistent buffer access type in LLVM codegen, to match C codegen
Lunderberg Feb 26, 2022
6d58d23
StorageRewrite, update indices of modified buffers.
Lunderberg Feb 26, 2022
0a9ebe6
Dynamic relay nodes, access 0-d tensors with 0-d indices.
Lunderberg Feb 28, 2022
99357d3
BFloat16 legalization, update buffer type.
Lunderberg Feb 28, 2022
9dd8afb
Updated meshgrid to use 0-d index for 0-d buffer.
Lunderberg Feb 28, 2022
bf2cc9e
Corrected boolean handling in Allocate nodes.
Lunderberg Feb 28, 2022
8dbc571
Added workaround to unpack 1-d Tensor indices into N-d buffer indices.
Lunderberg Feb 28, 2022
f6deec1
Resolved a few more failures in relay tests on cuda.
Lunderberg Feb 28, 2022
77ef980
Resolve linting
Lunderberg Mar 1, 2022
795c3fc
CI bump
Lunderberg Mar 1, 2022
aedf588
Merge branch 'main' into physical_layout
Lunderberg Mar 1, 2022
4703aa2
Updated renormalize_split_pattern tests to use BufferLoad/BufferStore
Lunderberg Mar 1, 2022
94abb53
Fixed cuda codegen checks for BufferStore/Ramp.
Lunderberg Mar 1, 2022
4df4be3
Simplify indices further, needed to avoid cuda register limit.
Lunderberg Mar 1, 2022
3373ecd
fixed dyn onehot shape func accessing 1d buffer with ()
masahi Mar 2, 2022
9835028
Fixed codegen indexing for int4 scalar types.
Lunderberg Mar 2, 2022
942cda1
Temporary workaround for incorrect constant folding.
Lunderberg Mar 2, 2022
298c4fc
Merge remote-tracking branch 'upstream/main' into physical_layout
masahi Mar 3, 2022
7c66f23
s/find_allocate_usage/FindAllocateUsage/g
Lunderberg Mar 3, 2022
e8c0e62
Added buffer type consistency TODO.
Lunderberg Mar 3, 2022
af2adf6
Improved comment on address_of Op.
Lunderberg Mar 3, 2022
50a73e1
Rename LegalizeDtype to LegalizeDType, made private.
Lunderberg Mar 3, 2022
619beb5
fix format and lint errors
adstraw Mar 3, 2022
e930e51
Disable vectorization of AllocateConst buffer in StorageRewrite.
Lunderberg Mar 3, 2022
5066425
Merge branch 'physical_layout' of github.com:Lunderberg/tvm into phys…
Lunderberg Mar 3, 2022
67dfedc
Merge branch 'main' into physical_layout
Lunderberg Mar 3, 2022
4971a02
Pass buffer_map through to the PrimFunc in cmsisnn
Lunderberg Mar 3, 2022
e6e149b
try disabling problematic winograd test case
masahi Mar 4, 2022
054af2e
try different way of buffer mapping in storage_rewrite
masahi Mar 4, 2022
c1d5fc2
Removed unnecessary ramp node in ir_builder.
Lunderberg Mar 4, 2022
f3b73de
Merge branch 'physical_layout' of github.com:Lunderberg/tvm into phys…
Lunderberg Mar 4, 2022
d1a5123
Fix lint error.
Lunderberg Mar 4, 2022
5dca3ff
Updated LLVM codegen for buffer indexing.
Lunderberg Mar 6, 2022
2b14d68
Merge branch 'main' into physical_layout
Lunderberg Mar 6, 2022
084c21c
Resolve lint error.
Lunderberg Mar 6, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions include/tvm/ir/attrs.h
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,47 @@ inline TFunc WithAttrs(TFunc input, Map<String, ObjectRef> attrs) {
return input;
}

/*!
* \brief Copy the function or module, but removes the specified
* attribute.
*
* \param input The thing to annotate (BaseFunc or IRModule)
* \param attr_key The attribute key.
*
* \tparam TFunc The corresponding function or module type.
*
* \returns The new function or module with removed attribute.
*
* \note This function performs copy on write optimization for func and module.
* If we move a uniquely referenced func or module into WithoutAttr,
* then no additional copy will be performed.
*
* This is also why we make it as a function instead of a member function
* and why we pass by value in the first argument.
*
* \code
*
* // Recommended way to trigger copy on write
* func = WithoutAttr(std::move(func), "key1");
* func = WithoutAttr(std::move(func), "key2");
*
* \endcode
*/
template <typename TFunc>
inline TFunc WithoutAttr(TFunc input, const std::string& attr_key) {
using TNode = typename TFunc::ContainerType;
static_assert(TNode::_type_final, "Can only operate on the leaf nodes");

if (input->attrs.defined()) {
TNode* node = input.CopyOnWrite();
node->attrs.CopyOnWrite()->dict.erase(attr_key);
if (node->attrs->dict.size() == 0) {
node->attrs = NullValue<DictAttrs>();
}
}
return input;
}

// Namespace containing detail implementations
namespace detail {
using runtime::TVMArgValue;
Expand Down
1 change: 1 addition & 0 deletions include/tvm/te/operation.h
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,7 @@ class ComputeOp : public Operation {
Array<IterVar> axis, Array<PrimExpr> body);

TVM_DEFINE_OBJECT_REF_METHODS(ComputeOp, Operation, ComputeOpNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(ComputeOpNode);
};

/*!
Expand Down
123 changes: 121 additions & 2 deletions include/tvm/te/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include <tvm/te/tensor.h>
#include <tvm/te/tensor_intrin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/index_map.h>

#include <string>
#include <unordered_map>
Expand Down Expand Up @@ -256,6 +257,41 @@ class Stage : public ObjectRef {
* \return reference to self.
*/
TVM_DLL Stage& rolling_buffer(); // NOLINT(*)
/*!
* \brief Defines a layout transformation to be applied to the buffer.
*
* The map from initial_index to final_index must be an
* invertible affine transformation.
*
* \param initial_indices An array of variables to represent a
* value's location in the tensor, using the pre-transformation
* layout. These variables are used as binding occurrences to
* represent the initial indices when applying the initial->final
* mapping, and should not occur elsewhere in the
* Schedule. (i.e. Pass in newly constructed variables, not the
* initial IterVar::var)
*
* \param final_indices An array of expressions, giving the
* value's location in the tensor, using the post-transformation layout.
* Expressions should be in terms of the variables given in
* initial_indices.
*
* \param out_iter_vars An optional output location for the updated
* loop iteration variables.
*
* \return reference to self
*/
TVM_DLL Stage& transform_layout(const Array<Var>& initial_indices,
const Array<PrimExpr>& final_indices,
Array<IterVar>* out_iter_vars = nullptr);
/*! \brief Defines separators between groups of axes.
*
* Used to define `BufferNode::axis_separators`, which has
* additional details.
*
* \param axis_separators A list of axis separators.
*/
TVM_DLL Stage& set_axis_separators(const Array<IntImm>& axis_separators);
/*!
* \brief whether the stage has been scheduled.
* \return whether the stage has been scheduled.
Expand Down Expand Up @@ -466,9 +502,27 @@ class StageNode : public Object {
* while origin_op remains fixed.
*/
Operation origin_op;
/*! \brief All the nodes in the iter var */
/*! \brief All the nodes in the iter var
*
* Each element of all_iter_vars represents an iteration variable
* that may appear within this stage's computation. Any element
* of `all_iter_vars` that is in `leaf_iter_vars` represents a
* variable that is directly defined and usable within the stage's
* computation. All other elements of `all_iter_vars` represent
* variables whose value must be computed from the variables in
* `leaf_iter_vars`. (e.g. Support index k has been split by
* ``ko, ki = s.split(k, factor=4)``. ko and ki will appear in
* `leaf_iter_vars`, while k will not, and must be computed as
* `4*ko + ki`.
*/
Array<IterVar> all_iter_vars;
/*! \brief The current active leaf iter vars in the stage. */
/*! \brief The current active leaf iter vars in the stage.
*
* Each element of leaf_iter_vars will either be replaced with the
* bound index (e.g. threadIdx.x), or will be expanded into a loop
* over the variable's extent. `leaf_iter_vars` is a subset of
* `all_iter_vars`.
*/
Array<IterVar> leaf_iter_vars;
/*!
* \brief Specify threads to be launched at the stage.
Expand Down Expand Up @@ -500,6 +554,14 @@ class StageNode : public Object {
bool double_buffer{false};
/*! \brief Whether apply rolling buffer optimization to this stage */
bool rolling_buffer{false};
/*! \brief Layout transformations to be applied onto the stage's tensors. */
Array<IndexMap> layout_transforms;
/*! \brief List of axes after which to divide physical axes.
*
* Used to populate `BufferNode::axis_separators`, which has
* additional details.
*/
Array<IntImm> axis_separators;
/*!
* \brief The parent group of the current stage.
* The stage cannot be assigned to stages outside the group.
Expand All @@ -522,6 +584,8 @@ class StageNode : public Object {
v->Visit("scope", &scope);
v->Visit("is_output", &is_output);
v->Visit("double_buffer", &double_buffer);
v->Visit("layout_transforms", &layout_transforms);
v->Visit("axis_separators", &axis_separators);
v->Visit("group", &group);
v->Visit("num_child_stages", &num_child_stages);
}
Expand Down Expand Up @@ -771,6 +835,61 @@ class Singleton : public IterVarRelation {
TVM_DEFINE_OBJECT_REF_METHODS(Singleton, IterVarRelation, SingletonNode);
};

/*!
* \brief Transform iterator according to some arbitrary expression.
*/
class TransformNode : public IterVarRelationNode {
public:
/*! \brief The loop variables that were replaced by the transformation.
*
* Prior to applying a layout transformation, these represent the
* loops to iterate over a tensor as it is being computed, following
* a row-major traversal of the tensor's original shape in the
* compute definition.
*/
Array<IterVar> original_variables;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we have docs for these variables ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you, and added.


/*! \brief The variables generated by the transformation.
*
* After to applying a layout transformation, these represent the
* loops to iterate over a tensor as it is being computed, following
* a row-major traversal of the transformed shape of the tensor.
*/
Array<IterVar> transformed_variables;

/*! \brief Map from the original variables to the transformed variables.
*
* Used to determine iterator ranges over the transformed variables.
*/
IndexMap forward_transformation;

/*! \brief Map from transformed variables to the original variables
*
* Used to rewrite expressions containing the original loop iterators
* in terms of the transformed loop iterators.
*/
IndexMap inverse_transformation;

void VisitAttrs(AttrVisitor* v) {
v->Visit("original_variables", &original_variables);
v->Visit("transformed_variables", &transformed_variables);
v->Visit("forward_transformation", &forward_transformation);
v->Visit("inverse_transformation", &inverse_transformation);
}

static constexpr const char* _type_key = "Transform";
TVM_DECLARE_FINAL_OBJECT_INFO(TransformNode, IterVarRelationNode);
};

class Transform : public IterVarRelation {
public:
TVM_DLL explicit Transform(Array<IterVar> original_variables,
Array<IterVar> transformed_variables, IndexMap forward_transformation,
IndexMap inverse_transformation);

TVM_DEFINE_OBJECT_REF_METHODS(Transform, IterVarRelation, TransformNode);
};

/*! \brief Container for specialization conditions. */
class SpecializedConditionNode : public Object {
public:
Expand Down
44 changes: 38 additions & 6 deletions include/tvm/tir/buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,22 @@ class BufferNode : public Object {
Var data;
/*! \brief data type in the content of the tensor */
DataType dtype;
/*! \brief The shape of the buffer */
/*! \brief The type of the buffer prior to flattening
*
* This contains the shape as it is accessed by
* BufferLoad/BufferStore nodes, and used by the low-level code
* generators.
*/
Array<PrimExpr> shape;
/*!
* \brief Separators between input axes when generating flattened output axes
*
* For buffers representing flat 1-d memory (e.g. any buffer in
* RAM), this should be an empty array. For buffers representing
* non-flat memory, each entry in axis_separators should be the
* first input axis that is part of a new flattened axis.
*/
Array<IntImm> axis_separators;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe this was discussed elsewhere (in which case please point me at that), what is the reasoning behind maintaining this in the IR. I would naively assume shape to be transformed on-demand basis ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are two separate steps to the transformation, one which re-orders the existing indices and a later transformation that produces flat memory buffers. The former is either on-demand (TIR schedules), or during the building of a PrimFunc (SchedulePostProcToPrimFunc for TE-based schedules). The latter is applied during the lowering steps, either in StorageFlatten (TE-based schedules) or FlattenBuffer (TIR-based schedules). We had some discussion (link to RFC section with some details, link to comments on the RFC) on having the flattening to 1-d be a special case of layout transformations that are applied on-demand. However, @vinx13 brought up that some of the TIR lowering steps would need to apply after axis transformation makes the shape of the transformed buffer available, but before buffer flattening removes that shape information.

Maintaining the axis_separators through to the buffer flattening allows the later transformation to produce buffers representing non-flat memory spaces (e.g. texture memory on GPUs).

/*!
* \brief The strides of each dimension
* This can be an empty array, indicating array is contiguous
Expand Down Expand Up @@ -89,6 +103,7 @@ class BufferNode : public Object {
v->Visit("dtype", &dtype);
v->Visit("shape", &shape);
v->Visit("strides", &strides);
v->Visit("axis_separators", &axis_separators);
v->Visit("elem_offset", &elem_offset);
v->Visit("name", &name);
v->Visit("data_alignment", &data_alignment);
Expand All @@ -98,10 +113,11 @@ class BufferNode : public Object {
}

bool SEqualReduce(const BufferNode* other, SEqualReducer equal) const {
// Use DefEqual as buffer can define variables
// in its semantics, skip name as name is not important.
// Use DefEqual as buffer can define variables in its semantics,
// skip name as name is not important.
return equal.DefEqual(data, other->data) && equal(dtype, other->dtype) &&
equal.DefEqual(shape, other->shape) && equal.DefEqual(strides, other->strides) &&
equal.DefEqual(axis_separators, other->axis_separators) &&
equal.DefEqual(elem_offset, other->elem_offset) &&
equal(data_alignment, other->data_alignment) && equal(buffer_type, other->buffer_type);
}
Expand All @@ -112,6 +128,7 @@ class BufferNode : public Object {
hash_reduce.DefHash(shape);
hash_reduce.DefHash(strides);
hash_reduce.DefHash(elem_offset);
hash_reduce.DefHash(axis_separators);
hash_reduce(data_alignment);
hash_reduce(buffer_type);
}
Expand All @@ -127,7 +144,7 @@ class BufferNode : public Object {
* without adjusting for number of lanes. (e.g. The number of
* float16x4 elements in a buffer of type float16x4.)
*/
PrimExpr ElemOffset(Array<PrimExpr> index) const;
Array<PrimExpr> ElemOffset(Array<PrimExpr> index) const;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: update comment

Would be nice to have aliases for PhysicalIndex, LogicalIndex ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point, though I think there would need to be some additional terminology introduced. The options I had thought about were below.

  • Keep the same function names. Not as informative as it could be, because it doesn't indicate whether the input/output indices correspond to specific indices elsewhere in the flow (e.g. those specified by the user or expressed to low-level code generation).
  • Rename to Array<PrimExpr> GetPhysicalIndex(Array<PrimExpr> logical_index). Not quite true, since the indices passed in should be after any layout transformations have been applied, but before flattening. logical_index would imply that it takes indices exactly as specified in the te.compute definition.
  • Rename to Array<PrimExpr> GetPhysicalIndex(Array<PrimExpr> transformed_index). Not quite true, since it implies that passing in post-transformation indices can always output the physical index. The Buffer object must be modified by the transformation first, before the transformed index can be passed in.
  • Rename to Array<PrimExpr> FlattenIndex(Array<PrimExpr> index). Actually, now that I look at that, I kind of like that one. It doesn't imply additional functionality, and

After writing out the different options, I think it makes sense to rename it to FlattenIndex, along with adding a comment that this should be used after all transformations have been applied, and produces the physical index.


static constexpr const char* _type_key = "tir.Buffer";
static constexpr const bool _type_has_method_sequal_reduce = true;
Expand All @@ -146,7 +163,7 @@ class Buffer : public ObjectRef {
// A default value will be picked.
TVM_DLL Buffer(Var data, DataType dtype, Array<PrimExpr> shape, Array<PrimExpr> strides,
PrimExpr elem_offset, String name, int data_alignment, int offset_factor,
BufferType buffer_type, Span span = Span());
BufferType buffer_type, Array<IntImm> axis_separators = {}, Span span = Span());

/*!
* \brief Return a new buffer that is equivalent with current one
Expand Down Expand Up @@ -186,6 +203,19 @@ class Buffer : public ObjectRef {
*/
TVM_DLL Stmt vstore(Array<PrimExpr> begin, PrimExpr value) const;

/*!
* \brief Get a flattened version of the buffer
*/
Buffer GetFlattenedBuffer() const;

/*! \brief Determine the offset in the buffer of the given index.
*
* Returns the buffer offset, in number of elements of type dtype,
* without adjusting for number of lanes. (e.g. The number of
* float16x4 elements in a buffer of type float16x4.)
*/
Array<PrimExpr> OffsetOf(Array<PrimExpr> index) const;

/*!
* \brief Return the storage scope associated with this buffer.
*/
Expand All @@ -201,12 +231,14 @@ class Buffer : public ObjectRef {
* \param dtype The content data type.
* \param name The name of the buffer
* \param storage_scope The storage scope associated with this buffer
* \param axis_separators Divisions defining the groups of axes that will be flattened together.
* \param span The location of this object in the source code.
* \return The created buffer.
* \sa Buffer for complete constructor.
*/
TVM_DLL Buffer decl_buffer(Array<PrimExpr> shape, DataType dtype = DataType::Float(32),
String name = "buffer", String storage_scope = "", Span span = Span());
String name = "buffer", String storage_scope = "",
Array<IntImm> axis_separators = {}, Span span = Span());

/*!
* \brief Base node for data producers.
Expand Down
11 changes: 8 additions & 3 deletions include/tvm/tir/builtin.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,10 +105,15 @@ TVM_DLL const Op& large_uint_imm();
TVM_DLL const Op& q_multiply_shift();

/*!
* \brief See pesudo code
* \brief Returns the address of an element in the buffer (see pseudocode below).
*
* The number of indices should match the dimensionality of the buffer
* being accessed. If this operation occurs after buffer flattening,
* the number of indices must be supported by the target (i.e. N>1
* only on targets that support non-flat memory buffers).
*
* Handle address_of(Load *op) {
* return &op->buffer_var[index];
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider restricting this to the case when len(op->indices) == 1 for now, if we do not need other cases

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it needs to allow non-flat indices, both for TIR prior to flattening (N-d indices for N-d logical shape of buffer) and for TIR post-flattening on supported targets (N-d indices for N-d physical shape). I've added to the docstring that the buffer address must be compatible with the targets used.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @Lunderberg To elaborate a bit further, address_of is a very low-level primitive and appears at a later stage of the buffer.

I believe what we are trying to say instead is that "there is a need to build a mechanism to take possible buffer subregion's head address so we can plug them into intrinsics"

While we would find need of taking address of multiple dimensional buffers, the specific design of the instrinsic may not be very desirable at early stage of scheduling, especially when it comes to the case when buffer layout and overall allocation regions changes.

Consider the following example:

@T.prim_func
def myfunc(A: T.Buffer([4,4], "float32"), C: T.Buffer([4,4], "float32")):
     B = T.alloc_buffer([4,4], "float32")
     for i, j in T.grid(2, 2):
           load_2x2_matrix(address_of(B[i*2, j*2], A, i*2, j*2)
           exp_2x2_matrix(address_of(B[i*2, j*2])
           store_2x2_matrix(address_of(B[i*2, j*2], C, i*2, j*2)

In this case, the original intention was to load a 2x2 region onto B, perform some arithmetic operations, then store it back to C. The main problem of this program is how does the address changes as we start to transform the program, say if we swap i and j axis of B, or in this case, what if we compact the layout of B to make it 2x2(as only 2x2 region is touched in each of the loop iterator).

Because address_of is a very low level construct, it is hard to clarify what does it mean to update the IR so that the address_of still makes sense. Additionally, there are also some unspoken cosntraints, for example, in many intrinsics there is a restriction that the submatrix have to stay continguous in memory. Which means B have to follow layout lambda i, j: i//2, j // 2, i %2, j%2 to be lowered.

The high level message is that address_of alone is not enough to capture the intention of slicing a submatrix with possible addressing constraints. TIR introduced a specific construct (buffer subregion match) to address this problem. Which matches a subregion to a new buffer, and the buffer specified all the possible constraints that backend might want to impose.

Process-wise, restricting the semantics to match the existing behavior would meet our need, but also not open up doors for possible future problems. It could be possible that after some deliberation we still think that address_of for multi-dimensional access is needed, then we can update and generalize at that time point

* Handle address_of(BufferLoad *op) {
* return &op->buffer_var[op->indices[0], op->indices[1], ..., op->indices[N-1]];
* }
*/
TVM_DLL const Op& address_of();
Expand Down
16 changes: 16 additions & 0 deletions include/tvm/tir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -630,6 +630,22 @@ class BufferLoadNode : public PrimExprNode {

static constexpr const char* _type_key = "tir.BufferLoad";
TVM_DECLARE_FINAL_OBJECT_INFO(BufferLoadNode, PrimExprNode);

private:
/*! \brief Set the dtype based on the buffer/indices
*
* Usually, the BufferLoad's dtype will be the same dtype as the
* buffer. This may have a different number of lanes than the
* buffer's dtype if index values have more than 1 lane.
*
* This function should only be called during construction and after
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add TODO, add WithIndices function that uses LegalizeDType, which is called by those friend classes.

* CopyOnWrite. Friend class used here to restrict usage.
*/
void LegalizeDType();
friend class BufferLoad;
friend class CustomDatatypesLowerer;
friend class VectorTypeRewriter;
friend class Vectorizer;
};

/*!
Expand Down
Loading