-
Notifications
You must be signed in to change notification settings - Fork 3.7k
[TE][TIR] Implement layout transformations, non-flat memory buffers #9727
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
63d2352
f698ce7
4331f36
ebdc8c7
bba004a
7c33db5
da05133
dba2031
cbc3a6b
7399589
24259f0
8e9ec52
0a26801
78d88b5
4fcb73b
97d6dc9
a523574
3211d61
74d75dc
ace7525
89081de
699bb17
c3ff6f6
acad83f
a62f449
25ff74c
3735b6f
1994b0f
ea4c10a
ce8e29a
0a2396b
fb4d71a
289f192
3396c05
e34221d
918ea2d
0ffe060
6e0ca38
417ee2b
e80914b
d87a561
09859a6
8b5aab4
bc1b957
6c111fc
fb1f9fb
d7b8f06
49eeeca
ff66339
5bebd09
2856fd9
4d19cc7
00292c3
029496e
3676f80
cafadcc
04995f6
9a22d14
e1a16b7
c69d569
b5644f2
777b379
fc13666
391a28b
6c1f861
7907083
293f54a
e38bec0
ed21da0
1f593d4
cf5555a
2119f0b
4877be3
a1e8ed4
2f46152
e6579b2
5601423
40eaef4
5cad482
4775f89
f990415
b26a2cd
f400062
01d699f
2f93970
17b963c
28f6339
4d6496f
6f879d0
4de03d4
418d1aa
0e614af
3b3e7fb
7c6ded0
068b179
b5a1428
0b5b840
249285a
7b98cae
75819ba
a096989
cb46d7e
5ae79fd
c0c9329
3b20b42
63941f5
216fa9a
0c4194b
5329a05
53c0362
c636e9b
e40414f
27552d6
b85b4ee
a8b5fa3
e2342dc
70d9d3c
2e09604
2029ced
c8f9015
8f97159
e3e3d89
077e2ba
d8b88a9
fa941c9
fb14c5e
120bb5b
e4c169d
6425882
4d02048
8bf6573
77841ae
c20709c
521556e
7a2eb8e
b08245f
f3d17b2
e8aa9d6
9fa1d07
b8710ad
3edb07d
ea0b4f9
3f52fa3
9d2564c
b476517
f29d417
bf65156
2c60f51
a79b0ac
62c3f90
07dc8ab
bc1e5ae
cc1f3ae
09d33bb
24297e3
14676bb
03f9164
6d58d23
0a9ebe6
99357d3
9dd8afb
bf2cc9e
8dbc571
f6deec1
77ef980
795c3fc
aedf588
4703aa2
94abb53
4df4be3
3373ecd
9835028
942cda1
298c4fc
7c66f23
e8c0e62
af2adf6
50a73e1
619beb5
e930e51
5066425
67dfedc
4971a02
e6e149b
054af2e
c1d5fc2
f3b73de
d1a5123
5dca3ff
2b14d68
084c21c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -29,6 +29,7 @@ | |
| #include <tvm/te/tensor.h> | ||
| #include <tvm/te/tensor_intrin.h> | ||
| #include <tvm/tir/expr.h> | ||
| #include <tvm/tir/index_map.h> | ||
|
|
||
| #include <string> | ||
| #include <unordered_map> | ||
|
|
@@ -256,6 +257,41 @@ class Stage : public ObjectRef { | |
| * \return reference to self. | ||
| */ | ||
| TVM_DLL Stage& rolling_buffer(); // NOLINT(*) | ||
| /*! | ||
| * \brief Defines a layout transformation to be applied to the buffer. | ||
| * | ||
| * The map from initial_index to final_index must be an | ||
| * invertible affine transformation. | ||
| * | ||
| * \param initial_indices An array of variables to represent a | ||
| * value's location in the tensor, using the pre-transformation | ||
| * layout. These variables are used as binding occurrences to | ||
| * represent the initial indices when applying the initial->final | ||
| * mapping, and should not occur elsewhere in the | ||
| * Schedule. (i.e. Pass in newly constructed variables, not the | ||
| * initial IterVar::var) | ||
| * | ||
| * \param final_indices An array of expressions, giving the | ||
| * value's location in the tensor, using the post-transformation layout. | ||
| * Expressions should be in terms of the variables given in | ||
| * initial_indices. | ||
| * | ||
| * \param out_iter_vars An optional output location for the updated | ||
| * loop iteration variables. | ||
| * | ||
| * \return reference to self | ||
| */ | ||
| TVM_DLL Stage& transform_layout(const Array<Var>& initial_indices, | ||
| const Array<PrimExpr>& final_indices, | ||
| Array<IterVar>* out_iter_vars = nullptr); | ||
| /*! \brief Defines separators between groups of axes. | ||
| * | ||
| * Used to define `BufferNode::axis_separators`, which has | ||
| * additional details. | ||
| * | ||
| * \param axis_separators A list of axis separators. | ||
| */ | ||
| TVM_DLL Stage& set_axis_separators(const Array<IntImm>& axis_separators); | ||
| /*! | ||
| * \brief whether the stage has been scheduled. | ||
| * \return whether the stage has been scheduled. | ||
|
|
@@ -466,9 +502,27 @@ class StageNode : public Object { | |
| * while origin_op remains fixed. | ||
| */ | ||
| Operation origin_op; | ||
| /*! \brief All the nodes in the iter var */ | ||
| /*! \brief All the nodes in the iter var | ||
| * | ||
| * Each element of all_iter_vars represents an iteration variable | ||
Lunderberg marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| * that may appear within this stage's computation. Any element | ||
| * of `all_iter_vars` that is in `leaf_iter_vars` represents a | ||
| * variable that is directly defined and usable within the stage's | ||
| * computation. All other elements of `all_iter_vars` represent | ||
| * variables whose value must be computed from the variables in | ||
| * `leaf_iter_vars`. (e.g. Support index k has been split by | ||
| * ``ko, ki = s.split(k, factor=4)``. ko and ki will appear in | ||
| * `leaf_iter_vars`, while k will not, and must be computed as | ||
| * `4*ko + ki`. | ||
| */ | ||
| Array<IterVar> all_iter_vars; | ||
| /*! \brief The current active leaf iter vars in the stage. */ | ||
| /*! \brief The current active leaf iter vars in the stage. | ||
| * | ||
| * Each element of leaf_iter_vars will either be replaced with the | ||
| * bound index (e.g. threadIdx.x), or will be expanded into a loop | ||
| * over the variable's extent. `leaf_iter_vars` is a subset of | ||
| * `all_iter_vars`. | ||
| */ | ||
| Array<IterVar> leaf_iter_vars; | ||
| /*! | ||
| * \brief Specify threads to be launched at the stage. | ||
|
|
@@ -500,6 +554,14 @@ class StageNode : public Object { | |
| bool double_buffer{false}; | ||
| /*! \brief Whether apply rolling buffer optimization to this stage */ | ||
| bool rolling_buffer{false}; | ||
| /*! \brief Layout transformations to be applied onto the stage's tensors. */ | ||
| Array<IndexMap> layout_transforms; | ||
| /*! \brief List of axes after which to divide physical axes. | ||
| * | ||
| * Used to populate `BufferNode::axis_separators`, which has | ||
| * additional details. | ||
| */ | ||
| Array<IntImm> axis_separators; | ||
| /*! | ||
| * \brief The parent group of the current stage. | ||
| * The stage cannot be assigned to stages outside the group. | ||
|
|
@@ -522,6 +584,8 @@ class StageNode : public Object { | |
| v->Visit("scope", &scope); | ||
| v->Visit("is_output", &is_output); | ||
| v->Visit("double_buffer", &double_buffer); | ||
| v->Visit("layout_transforms", &layout_transforms); | ||
| v->Visit("axis_separators", &axis_separators); | ||
| v->Visit("group", &group); | ||
| v->Visit("num_child_stages", &num_child_stages); | ||
| } | ||
|
|
@@ -771,6 +835,61 @@ class Singleton : public IterVarRelation { | |
| TVM_DEFINE_OBJECT_REF_METHODS(Singleton, IterVarRelation, SingletonNode); | ||
| }; | ||
|
|
||
| /*! | ||
| * \brief Transform iterator according to some arbitrary expression. | ||
| */ | ||
| class TransformNode : public IterVarRelationNode { | ||
| public: | ||
| /*! \brief The loop variables that were replaced by the transformation. | ||
| * | ||
| * Prior to applying a layout transformation, these represent the | ||
| * loops to iterate over a tensor as it is being computed, following | ||
| * a row-major traversal of the tensor's original shape in the | ||
| * compute definition. | ||
| */ | ||
| Array<IterVar> original_variables; | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should we have docs for these variables ?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thank you, and added. |
||
|
|
||
| /*! \brief The variables generated by the transformation. | ||
| * | ||
| * After to applying a layout transformation, these represent the | ||
| * loops to iterate over a tensor as it is being computed, following | ||
| * a row-major traversal of the transformed shape of the tensor. | ||
| */ | ||
| Array<IterVar> transformed_variables; | ||
|
|
||
| /*! \brief Map from the original variables to the transformed variables. | ||
| * | ||
| * Used to determine iterator ranges over the transformed variables. | ||
| */ | ||
| IndexMap forward_transformation; | ||
|
|
||
| /*! \brief Map from transformed variables to the original variables | ||
| * | ||
| * Used to rewrite expressions containing the original loop iterators | ||
| * in terms of the transformed loop iterators. | ||
| */ | ||
| IndexMap inverse_transformation; | ||
|
|
||
| void VisitAttrs(AttrVisitor* v) { | ||
| v->Visit("original_variables", &original_variables); | ||
| v->Visit("transformed_variables", &transformed_variables); | ||
| v->Visit("forward_transformation", &forward_transformation); | ||
| v->Visit("inverse_transformation", &inverse_transformation); | ||
| } | ||
|
|
||
| static constexpr const char* _type_key = "Transform"; | ||
| TVM_DECLARE_FINAL_OBJECT_INFO(TransformNode, IterVarRelationNode); | ||
| }; | ||
|
|
||
| class Transform : public IterVarRelation { | ||
| public: | ||
| TVM_DLL explicit Transform(Array<IterVar> original_variables, | ||
| Array<IterVar> transformed_variables, IndexMap forward_transformation, | ||
| IndexMap inverse_transformation); | ||
|
|
||
| TVM_DEFINE_OBJECT_REF_METHODS(Transform, IterVarRelation, TransformNode); | ||
| }; | ||
|
|
||
| /*! \brief Container for specialization conditions. */ | ||
| class SpecializedConditionNode : public Object { | ||
| public: | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -55,8 +55,22 @@ class BufferNode : public Object { | |
| Var data; | ||
| /*! \brief data type in the content of the tensor */ | ||
| DataType dtype; | ||
| /*! \brief The shape of the buffer */ | ||
| /*! \brief The type of the buffer prior to flattening | ||
| * | ||
| * This contains the shape as it is accessed by | ||
| * BufferLoad/BufferStore nodes, and used by the low-level code | ||
| * generators. | ||
| */ | ||
| Array<PrimExpr> shape; | ||
| /*! | ||
| * \brief Separators between input axes when generating flattened output axes | ||
| * | ||
| * For buffers representing flat 1-d memory (e.g. any buffer in | ||
| * RAM), this should be an empty array. For buffers representing | ||
| * non-flat memory, each entry in axis_separators should be the | ||
| * first input axis that is part of a new flattened axis. | ||
| */ | ||
| Array<IntImm> axis_separators; | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe this was discussed elsewhere (in which case please point me at that), what is the reasoning behind maintaining this in the IR. I would naively assume shape to be transformed on-demand basis ?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There are two separate steps to the transformation, one which re-orders the existing indices and a later transformation that produces flat memory buffers. The former is either on-demand (TIR schedules), or during the building of a PrimFunc (SchedulePostProcToPrimFunc for TE-based schedules). The latter is applied during the lowering steps, either in StorageFlatten (TE-based schedules) or FlattenBuffer (TIR-based schedules). We had some discussion (link to RFC section with some details, link to comments on the RFC) on having the flattening to 1-d be a special case of layout transformations that are applied on-demand. However, @vinx13 brought up that some of the TIR lowering steps would need to apply after axis transformation makes the shape of the transformed buffer available, but before buffer flattening removes that shape information. Maintaining the |
||
| /*! | ||
| * \brief The strides of each dimension | ||
| * This can be an empty array, indicating array is contiguous | ||
|
|
@@ -89,6 +103,7 @@ class BufferNode : public Object { | |
| v->Visit("dtype", &dtype); | ||
| v->Visit("shape", &shape); | ||
| v->Visit("strides", &strides); | ||
| v->Visit("axis_separators", &axis_separators); | ||
| v->Visit("elem_offset", &elem_offset); | ||
| v->Visit("name", &name); | ||
| v->Visit("data_alignment", &data_alignment); | ||
|
|
@@ -98,10 +113,11 @@ class BufferNode : public Object { | |
| } | ||
|
|
||
| bool SEqualReduce(const BufferNode* other, SEqualReducer equal) const { | ||
| // Use DefEqual as buffer can define variables | ||
| // in its semantics, skip name as name is not important. | ||
| // Use DefEqual as buffer can define variables in its semantics, | ||
| // skip name as name is not important. | ||
| return equal.DefEqual(data, other->data) && equal(dtype, other->dtype) && | ||
| equal.DefEqual(shape, other->shape) && equal.DefEqual(strides, other->strides) && | ||
| equal.DefEqual(axis_separators, other->axis_separators) && | ||
| equal.DefEqual(elem_offset, other->elem_offset) && | ||
| equal(data_alignment, other->data_alignment) && equal(buffer_type, other->buffer_type); | ||
| } | ||
|
|
@@ -112,6 +128,7 @@ class BufferNode : public Object { | |
| hash_reduce.DefHash(shape); | ||
| hash_reduce.DefHash(strides); | ||
| hash_reduce.DefHash(elem_offset); | ||
| hash_reduce.DefHash(axis_separators); | ||
| hash_reduce(data_alignment); | ||
| hash_reduce(buffer_type); | ||
| } | ||
|
|
@@ -127,7 +144,7 @@ class BufferNode : public Object { | |
| * without adjusting for number of lanes. (e.g. The number of | ||
| * float16x4 elements in a buffer of type float16x4.) | ||
| */ | ||
| PrimExpr ElemOffset(Array<PrimExpr> index) const; | ||
| Array<PrimExpr> ElemOffset(Array<PrimExpr> index) const; | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: update comment Would be nice to have aliases for PhysicalIndex, LogicalIndex ?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good point, though I think there would need to be some additional terminology introduced. The options I had thought about were below.
After writing out the different options, I think it makes sense to rename it to |
||
|
|
||
| static constexpr const char* _type_key = "tir.Buffer"; | ||
| static constexpr const bool _type_has_method_sequal_reduce = true; | ||
|
|
@@ -146,7 +163,7 @@ class Buffer : public ObjectRef { | |
| // A default value will be picked. | ||
| TVM_DLL Buffer(Var data, DataType dtype, Array<PrimExpr> shape, Array<PrimExpr> strides, | ||
| PrimExpr elem_offset, String name, int data_alignment, int offset_factor, | ||
| BufferType buffer_type, Span span = Span()); | ||
| BufferType buffer_type, Array<IntImm> axis_separators = {}, Span span = Span()); | ||
|
|
||
| /*! | ||
| * \brief Return a new buffer that is equivalent with current one | ||
|
|
@@ -186,6 +203,19 @@ class Buffer : public ObjectRef { | |
| */ | ||
| TVM_DLL Stmt vstore(Array<PrimExpr> begin, PrimExpr value) const; | ||
|
|
||
| /*! | ||
| * \brief Get a flattened version of the buffer | ||
| */ | ||
| Buffer GetFlattenedBuffer() const; | ||
|
|
||
| /*! \brief Determine the offset in the buffer of the given index. | ||
| * | ||
| * Returns the buffer offset, in number of elements of type dtype, | ||
| * without adjusting for number of lanes. (e.g. The number of | ||
| * float16x4 elements in a buffer of type float16x4.) | ||
| */ | ||
| Array<PrimExpr> OffsetOf(Array<PrimExpr> index) const; | ||
|
|
||
| /*! | ||
| * \brief Return the storage scope associated with this buffer. | ||
| */ | ||
|
|
@@ -201,12 +231,14 @@ class Buffer : public ObjectRef { | |
| * \param dtype The content data type. | ||
| * \param name The name of the buffer | ||
| * \param storage_scope The storage scope associated with this buffer | ||
| * \param axis_separators Divisions defining the groups of axes that will be flattened together. | ||
| * \param span The location of this object in the source code. | ||
| * \return The created buffer. | ||
| * \sa Buffer for complete constructor. | ||
| */ | ||
| TVM_DLL Buffer decl_buffer(Array<PrimExpr> shape, DataType dtype = DataType::Float(32), | ||
| String name = "buffer", String storage_scope = "", Span span = Span()); | ||
| String name = "buffer", String storage_scope = "", | ||
| Array<IntImm> axis_separators = {}, Span span = Span()); | ||
|
|
||
| /*! | ||
| * \brief Base node for data producers. | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -105,10 +105,15 @@ TVM_DLL const Op& large_uint_imm(); | |
| TVM_DLL const Op& q_multiply_shift(); | ||
|
|
||
| /*! | ||
| * \brief See pesudo code | ||
| * \brief Returns the address of an element in the buffer (see pseudocode below). | ||
| * | ||
| * The number of indices should match the dimensionality of the buffer | ||
| * being accessed. If this operation occurs after buffer flattening, | ||
| * the number of indices must be supported by the target (i.e. N>1 | ||
| * only on targets that support non-flat memory buffers). | ||
| * | ||
| * Handle address_of(Load *op) { | ||
| * return &op->buffer_var[index]; | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Consider restricting this to the case when len(op->indices) == 1 for now, if we do not need other cases
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it needs to allow non-flat indices, both for TIR prior to flattening (N-d indices for N-d logical shape of buffer) and for TIR post-flattening on supported targets (N-d indices for N-d physical shape). I've added to the docstring that the buffer address must be compatible with the targets used.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks @Lunderberg To elaborate a bit further, I believe what we are trying to say instead is that "there is a need to build a mechanism to take possible buffer subregion's head address so we can plug them into intrinsics" While we would find need of taking address of multiple dimensional buffers, the specific design of the instrinsic may not be very desirable at early stage of scheduling, especially when it comes to the case when buffer layout and overall allocation regions changes. Consider the following example: @T.prim_func
def myfunc(A: T.Buffer([4,4], "float32"), C: T.Buffer([4,4], "float32")):
B = T.alloc_buffer([4,4], "float32")
for i, j in T.grid(2, 2):
load_2x2_matrix(address_of(B[i*2, j*2], A, i*2, j*2)
exp_2x2_matrix(address_of(B[i*2, j*2])
store_2x2_matrix(address_of(B[i*2, j*2], C, i*2, j*2)In this case, the original intention was to load a 2x2 region onto B, perform some arithmetic operations, then store it back to C. The main problem of this program is how does the address changes as we start to transform the program, say if we swap i and j axis of B, or in this case, what if we compact the layout of B to make it 2x2(as only 2x2 region is touched in each of the loop iterator). Because The high level message is that Process-wise, restricting the semantics to match the existing behavior would meet our need, but also not open up doors for possible future problems. It could be possible that after some deliberation we still think that |
||
| * Handle address_of(BufferLoad *op) { | ||
| * return &op->buffer_var[op->indices[0], op->indices[1], ..., op->indices[N-1]]; | ||
| * } | ||
| */ | ||
| TVM_DLL const Op& address_of(); | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -630,6 +630,22 @@ class BufferLoadNode : public PrimExprNode { | |
|
|
||
| static constexpr const char* _type_key = "tir.BufferLoad"; | ||
| TVM_DECLARE_FINAL_OBJECT_INFO(BufferLoadNode, PrimExprNode); | ||
|
|
||
| private: | ||
| /*! \brief Set the dtype based on the buffer/indices | ||
| * | ||
| * Usually, the BufferLoad's dtype will be the same dtype as the | ||
| * buffer. This may have a different number of lanes than the | ||
| * buffer's dtype if index values have more than 1 lane. | ||
| * | ||
| * This function should only be called during construction and after | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add TODO, add WithIndices function that uses LegalizeDType, which is called by those friend classes. |
||
| * CopyOnWrite. Friend class used here to restrict usage. | ||
| */ | ||
| void LegalizeDType(); | ||
| friend class BufferLoad; | ||
| friend class CustomDatatypesLowerer; | ||
| friend class VectorTypeRewriter; | ||
| friend class Vectorizer; | ||
| }; | ||
|
|
||
| /*! | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.