Skip to content

Commit 71d6f46

Browse files
authored
fix: select narrow dtype (#10519)
1 parent 57556fe commit 71d6f46

File tree

1 file changed

+18
-0
lines changed

1 file changed

+18
-0
lines changed

src/tir/transforms/narrow_datatype.cc

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,24 @@ class DataTypeRewriter : public StmtExprMutator {
293293
return StmtExprMutator::VisitExpr_(op);
294294
}
295295

296+
PrimExpr VisitExpr_(const SelectNode* op) final {
297+
PrimExpr condition = this->VisitExpr(op->condition);
298+
PrimExpr true_value = this->VisitExpr(op->true_value);
299+
PrimExpr false_value = this->VisitExpr(op->false_value);
300+
if (condition.same_as(op->condition) && true_value.same_as(op->true_value) &&
301+
false_value.same_as(op->false_value)) {
302+
return GetRef<PrimExpr>(op);
303+
} else {
304+
if (op->true_value.dtype().is_int() && op->false_value.dtype().is_int()) {
305+
int bits = std::max(true_value.dtype().bits(), false_value.dtype().bits());
306+
DataType dtype = true_value.dtype().with_bits(bits);
307+
if (true_value.dtype() != dtype) true_value = cast(dtype, true_value);
308+
if (false_value.dtype() != dtype) false_value = cast(dtype, false_value);
309+
}
310+
return Select(condition, true_value, false_value);
311+
}
312+
}
313+
296314
PrimExpr VisitExpr_(const RampNode* op) final {
297315
PrimExpr base = VisitExpr(op->base);
298316
PrimExpr stride = VisitExpr(op->stride);

0 commit comments

Comments
 (0)