[Relax][Bugfix] Infer TIR values from shapes inside a tuple #17312
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
If a Relax function contains an
R.match_castthat defines a symbolic shape, and the value provided to theR.match_casthas a known static shape, therelax.transform.CanoncalizeBindings()pass can in-line the known static shape. However, while these known TIR values were only collected if the expression used inR.match_castwas aR.Tensor,R.Shape, andR.Prim(Relax types which may contain symbolic TIR values), they were not collected if theR.match_castexpression was aR.Tuple.For example, while using
R.match_castto convert fromR.Tensor([16])toR.Tensor([batch_size])would identify thatbatch_sizemust be16, usingR.match_castto convert fromR.Tuple(R.Tensor([16]))toR.Tuple(R.Tensor([batch_size]))would not.This commit updates the
InferSymbolicVarMapto collect all symbolic shapes, even if they occur within aR.Tuple.