Skip to content

Conversation

@Lunderberg
Copy link
Contributor

Prior to this commit, if a TIR variable was required to compute the output of BlockBuilder.call_te, but that TIR variable could not be inferred from the shape of any tensor arguments, it would be provided in an optional tir_vars argument to R.call_tir. In C++, this would be then be accessed as an optional
call->args[2].as<ShapeExprNode>().

This extra argument can cause unexpected bugs. For example, the bug that was fixed in #17086 was caused by RewriteDataflowReshape identifying the output buffer using prim_func->buffer_map.Get(prim_func->params.back()), which is only correct if tir_vars is empty. Rather than fixing these issues as they come up, it would be better to make the general Relax guarantees stronger by removing the tir_vars argument altogether.

Use of extra R.shape parameter to specify additional tir_vars predates the existence of relax::PrimValue, and is no longer required. This commit updates BlockBuilder.call_te to use additional relax.PrimValue arguments to handle symbolic values that cannot be inferred from tensor shapes, rather than tir_vars.

@Lunderberg Lunderberg requested a review from masahi June 12, 2024 19:33
@masahi
Copy link
Member

masahi commented Jun 12, 2024

cc @tqchen @Hzfengsy

@Lunderberg Lunderberg requested a review from sunggg June 18, 2024 18:59
@Lunderberg Lunderberg force-pushed the relax_block_builder_prim_value branch from b463cba to 080defd Compare September 11, 2024 16:25
@Lunderberg
Copy link
Contributor Author

Rebased onto main to avoid stale CI results.

Prior to this commit, if a TIR variable was required to compute the
output of `BlockBuilder.call_te`, but that TIR variable could not be
inferred from the shape of any tensor arguments, it would be provided
in an optional `tir_vars` argument to `R.call_tir`.  In C++, this
would be then be accessed as an optional
`call->args[2].as<ShapeExprNode>()`.

This extra argument can cause unexpected bugs.  For example,
`RewriteDataflowReshape` identifies the output buffer using
`prim_func->buffer_map.Get(prim_func->params.back())`, which is only
correct if `tir_vars` is empty.  Rather than fixing these issues as
they come up, it would be better to make the general Relax guarantees
stronger by removing the `tir_vars` argument altogether.

Use of extra `R.shape` parameter to specify additional `tir_vars`
predates the existence of `relax::PrimValue`, and is no longer
required.  This commit updates `BlockBuilder.call_te` to use
additional `relax.PrimValue` arguments to handle symbolic values that
cannot be inferred from tensor shapes, rather than `tir_vars`.
@Lunderberg Lunderberg force-pushed the relax_block_builder_prim_value branch from 080defd to 95e018d Compare September 12, 2024 18:44
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants