Skip to content

Conversation

@Anndrey24
Copy link
Contributor

When inserting a cache_read / cache_write stage, the tir.AllocateConst statement would be duplicated if its body was not a tir.SeqStmt node (e.g. tir.For), leading to compilation failures. This happened because tir.AllocateConst and tir.DeclBuffer statements are always re-attached to the statement's body after the cache_read / cache_write stage is inserted in it, but the stage was being appended to the whole statement (which already contains the tir.AllocateConst) and not just its body, causing duplications.

This commit also adds a test where the first cache_read stage is inserted into a statement whose body is a tir.For, while the second stage is added to a body that is tir.SeqStmt to check for regressions.

cc @Lunderberg @ekalda @lhutton1

Copy link
Contributor

@Lunderberg Lunderberg left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good, and thank you for the fix!

Long-term, we may want to avoid this class of bugs by changing the const Stmt& stmt argument of InsertCacheState to Stmt body. This would avoid having two almost-but-not-quite identical variables within the same scope.

Lunderberg added a commit to Lunderberg/tvm that referenced this pull request Mar 1, 2024
Prior to this commit, the automatic `T.reads()` and `T.writes()`
annotations were only generated for buffers appearing as function
arguments, as `T.alloc_buffer` in a `T.block`, or as `T.match_buffer`
in a `T.block`.  However, inferred `T.reads()` for a buffer defined by
the `"tir.BindParams"` pass would be erroneously missing.  These
annotations may be required for correct scheduling (see discussion in
[PR#16660](apache#16660)).

This commit updates the TVMScript parsing to infer `T.reads()` and
`T.writes()` annotations for buffers defined with `DeclBuffer` nodes.
Lunderberg added a commit to Lunderberg/tvm that referenced this pull request Mar 1, 2024
Prior to this commit, the automatic `T.reads()` and `T.writes()`
annotations were only generated for buffers appearing as function
arguments, as `T.alloc_buffer` in a `T.block`, or as `T.match_buffer`
in a `T.block`.  However, inferred `T.reads()` for a buffer defined by
the `"tir.BindParams"` pass would be erroneously missing.  These
annotations may be required for correct scheduling (see discussion in
[PR#16660](apache#16660)).

This commit updates the TVMScript parsing to infer `T.reads()` and
`T.writes()` annotations for buffers defined with `DeclBuffer` nodes.
@lhutton1
Copy link
Contributor

lhutton1 commented Mar 4, 2024

@tvm-bot rerun

tqchen pushed a commit that referenced this pull request Mar 4, 2024
Prior to this commit, the automatic `T.reads()` and `T.writes()`
annotations were only generated for buffers appearing as function
arguments, as `T.alloc_buffer` in a `T.block`, or as `T.match_buffer`
in a `T.block`.  However, inferred `T.reads()` for a buffer defined by
the `"tir.BindParams"` pass would be erroneously missing.  These
annotations may be required for correct scheduling (see discussion in
[PR#16660](#16660)).

This commit updates the TVMScript parsing to infer `T.reads()` and
`T.writes()` annotations for buffers defined with `DeclBuffer` nodes.
Copy link
Contributor

@lhutton1 lhutton1 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @Anndrey24, LGTM. Since #16663 was merged would you prefer to address the comment here or address in a follow up?

…primitive

When inserting a `cache_read` / `cache_write` stage, the `tir.AllocateConst` statement would be duplicated if its body was not a `tir.SeqStmt` node (e.g. `tir.For`), leading to compilation failures. This happened because `tir.AllocateConst` and `tir.DeclBuffer` statements are always re-attached to the statement's body after the `cache_read` / `cache_write` stage is inserted in it, but the stage was being appended to the whole statement (which already contains the `tir.AllocateConst`) and not just its body, causing duplications.

This commit also adds a test where the first `cache_read` stage is inserted into a statement whose body is a `tir.For`, while the second stage is added to a body that is `tir.SeqStmt` to check for regressions.
@Anndrey24
Copy link
Contributor Author

Oops, I hadn't noticed the fix was already merged. I've also removed the T.reads() now! :)

@lhutton1
Copy link
Contributor

lhutton1 commented Mar 5, 2024

@tvm-bot rerun

1 similar comment
@Anndrey24
Copy link
Contributor Author

@tvm-bot rerun

@lhutton1 lhutton1 merged commit 657880c into apache:main Mar 7, 2024
@lhutton1
Copy link
Contributor

lhutton1 commented Mar 7, 2024

Thanks @Anndrey24 @Lunderberg!

@Anndrey24 Anndrey24 deleted the fix-alloc-const branch March 7, 2024 12:53
Lunderberg added a commit to Lunderberg/tvm that referenced this pull request Mar 12, 2024
Prior to this commit, the automatic `T.reads()` and `T.writes()`
annotations were only generated for buffers appearing as function
arguments, as `T.alloc_buffer` in a `T.block`, or as `T.match_buffer`
in a `T.block`.  However, inferred `T.reads()` for a buffer defined by
the `"tir.BindParams"` pass would be erroneously missing.  These
annotations may be required for correct scheduling (see discussion in
[PR#16660](apache#16660)).

This commit updates the TVMScript parsing to infer `T.reads()` and
`T.writes()` annotations for buffers defined with `DeclBuffer` nodes.
Lunderberg pushed a commit to Lunderberg/tvm that referenced this pull request Mar 12, 2024
…primitive (apache#16660)

* [Bugfix][TIR] Fix duplicate AllocateConst in CacheReadWrite schedule primitive

When inserting a `cache_read` / `cache_write` stage, the `tir.AllocateConst` statement would be duplicated if its body was not a `tir.SeqStmt` node (e.g. `tir.For`), leading to compilation failures. This happened because `tir.AllocateConst` and `tir.DeclBuffer` statements are always re-attached to the statement's body after the `cache_read` / `cache_write` stage is inserted in it, but the stage was being appended to the whole statement (which already contains the `tir.AllocateConst`) and not just its body, causing duplications.

This commit also adds a test where the first `cache_read` stage is inserted into a statement whose body is a `tir.For`, while the second stage is added to a body that is `tir.SeqStmt` to check for regressions.

* Improve PrimFunc readability

* Remove redundant `T.reads()`
thaisacs pushed a commit to thaisacs/tvm that referenced this pull request Apr 3, 2024
Prior to this commit, the automatic `T.reads()` and `T.writes()`
annotations were only generated for buffers appearing as function
arguments, as `T.alloc_buffer` in a `T.block`, or as `T.match_buffer`
in a `T.block`.  However, inferred `T.reads()` for a buffer defined by
the `"tir.BindParams"` pass would be erroneously missing.  These
annotations may be required for correct scheduling (see discussion in
[PR#16660](apache#16660)).

This commit updates the TVMScript parsing to infer `T.reads()` and
`T.writes()` annotations for buffers defined with `DeclBuffer` nodes.
thaisacs pushed a commit to thaisacs/tvm that referenced this pull request Apr 3, 2024
…primitive (apache#16660)

* [Bugfix][TIR] Fix duplicate AllocateConst in CacheReadWrite schedule primitive

When inserting a `cache_read` / `cache_write` stage, the `tir.AllocateConst` statement would be duplicated if its body was not a `tir.SeqStmt` node (e.g. `tir.For`), leading to compilation failures. This happened because `tir.AllocateConst` and `tir.DeclBuffer` statements are always re-attached to the statement's body after the `cache_read` / `cache_write` stage is inserted in it, but the stage was being appended to the whole statement (which already contains the `tir.AllocateConst`) and not just its body, causing duplications.

This commit also adds a test where the first `cache_read` stage is inserted into a statement whose body is a `tir.For`, while the second stage is added to a body that is `tir.SeqStmt` to check for regressions.

* Improve PrimFunc readability

* Remove redundant `T.reads()`
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants