[TIR] Fix Ramp int32~64 mismatch in VectorizeLoop and NarrowDataType passes #10172
Not sure whom I should ask for reviews, but this seems similar to PR #9582, so cc'ing the reviewers there: @YuchenJin @junrushao1994 @Mousius
LGTM and thanks for the contribution! But I am wondering if we can simply perform this compatible casting when constructing a Ramp node (https://github.com/apache/tvm/tree/main/src/tir/ir/expr.cc#L705)?
We should add ICHECK(base.is_int() && stride.is_int()) if the two can only be integers.
if (base.dtype().is_int()) {
  ICHECK(stride.dtype().is_int()) << "Ramp base is int but stride is " << stride.dtype();
I think we can simply assume that base and stride should be of integer types. However, I also noticed that at line 705 in 22c488e,
Ramp::Ramp(PrimExpr base, PrimExpr stride, int lanes, Span span) {
such assumptions are not checked. I am a bit curious if there will be, say, base/stride in float types?
I don't know if they can be floats, so I added that conservative line of code, which only acts on integers. If they can only be integers, we should add the ICHECK you mention.
@lazycal Using this impl (based on yours) in https://github.com/lazycal/tvm/blob/ffe6649855c4c247f4bb85c9d48c5ca157850a1d/src/tir/ir/expr.cc#L705 fixes the bug you mentioned and might be more general to overcome other hidden ones if we can assume that
Ramp::Ramp(PrimExpr base, PrimExpr stride, int lanes, Span span) {
  ICHECK(base.defined());
  ICHECK(stride.defined());
  ICHECK(base.dtype().is_scalar());
  ICHECK(stride.dtype().is_scalar());
  ICHECK_GT(lanes, 1);
  ICHECK(base.dtype().is_int());
  ICHECK(stride.dtype().is_int());
  if (base.dtype() != stride.dtype()) {
    size_t bits = std::max(base.dtype().bits(), stride.dtype().bits());
    DataType dtype = base.dtype().with_bits(bits);
    if (base.dtype() != dtype) base = cast(dtype, base);
    if (stride.dtype() != dtype) stride = cast(dtype, stride);
  }
  ObjectPtr<RampNode> node = make_object<RampNode>();
  node->dtype = base.dtype().with_lanes(lanes);
  node->base = base;
  node->stride = stride;
  node->lanes = lanes;
  node->span = std::move(span);
  data_ = std::move(node);
}
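For anyone who wants to try the widening rule from the constructor above outside of TVM, here is a minimal standalone sketch. ToyDType, ToyRamp and MakeToyRamp are hypothetical names made up for illustration and are not part of TVM; the sketch only tracks the type code and bit width, and setting the bit width stands in for the real cast(dtype, expr).

```cpp
#include <algorithm>
#include <cassert>
#include <stdexcept>

// Toy stand-in for a scalar dtype. Hypothetical; this is NOT tvm::DataType.
struct ToyDType {
  enum class Code { kInt, kFloat };
  Code code;
  int bits;
  bool is_int() const { return code == Code::kInt; }
  bool operator==(const ToyDType& o) const { return code == o.code && bits == o.bits; }
  bool operator!=(const ToyDType& o) const { return !(*this == o); }
};

// Toy stand-in for a Ramp node: only the operand dtypes and the lane count.
struct ToyRamp {
  ToyDType base;
  ToyDType stride;
  int lanes;
};

// Mirrors the rule in the constructor above: reject non-integer operands,
// then promote both operands to the wider of the two bit widths.
ToyRamp MakeToyRamp(ToyDType base, ToyDType stride, int lanes) {
  assert(lanes > 1);
  if (!base.is_int() || !stride.is_int()) {
    throw std::invalid_argument("Ramp base and stride must be integers");
  }
  const int bits = std::max(base.bits, stride.bits);
  base.bits = bits;    // stands in for cast(dtype, base)
  stride.bits = bits;  // stands in for cast(dtype, stride)
  return ToyRamp{base, stride, lanes};
}
```

With an int64 base and an int32 stride, the result has both operands at 64 bits, so an equality check between the two dtypes would pass. A float operand is rejected up front, which is the stricter behavior discussed in this thread.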
I didn't do what you said because
@lazycal Fair consideration! I also tried
I've just hit a similar error, when compiling an int8 model with tensorized ops (VNNI): I wonder if this is related.
…passes (apache#10172) [TIR] Fix Ramp int32~64 mismatch in VectorizeLoop and NarrowDataType passes
The following model
triggers two issues regarding base and stride dtype mismatch in Ramp, one in the VectorizeLoop pass and the other in the NarrowDataType pass. The error message looks like: Check failed: stride.dtype() == base.dtype() (int32 vs. int64) :
The fix
int32. This PR changes it to use the loop variable's dtype.
stride is inferred with int32, but base is not (see the added test case for detail). This PR adds an upcast when rewriting a Ramp node whose base and stride are inferred with different numbers of bits.
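To make the first part of the fix concrete, here is a small standalone sketch of why taking the stride's dtype from the loop variable avoids the failed check. It assumes the pre-fix behavior was a stride fixed at int32, which the description and the error message suggest but do not fully spell out. IntType and the three functions are hypothetical names for illustration only, not TVM APIs.

```cpp
#include <cassert>

// Toy integer dtype holding only a bit width. Hypothetical; not tvm::DataType.
struct IntType {
  int bits;
};

// Models the check quoted in the error message:
// stride.dtype() == base.dtype().
bool RampDTypesAgree(IntType base, IntType stride) {
  assert(base.bits > 0 && stride.bits > 0);
  return base.bits == stride.bits;
}

// Assumed pre-fix behavior: the stride constant is always int32,
// regardless of the loop variable's dtype.
IntType StrideTypeHardCoded(IntType /*loop_var*/) { return IntType{32}; }

// Post-fix behavior as described: the stride takes the loop variable's dtype.
IntType StrideTypeFromLoopVar(IntType loop_var) { return loop_var; }
```

With an int64 loop variable used as the ramp base, the hard-coded variant produces an int32 stride and the check fails with the int32 vs. int64 mismatch, while the loop-variable variant produces an int64 stride and the check passes.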