File tree Expand file tree Collapse file tree 1 file changed +18
-0
lines changed Expand file tree Collapse file tree 1 file changed +18
-0
lines changed Original file line number Diff line number Diff line change @@ -293,6 +293,24 @@ class DataTypeRewriter : public StmtExprMutator {
293293 return StmtExprMutator::VisitExpr_ (op);
294294 }
295295
296+ PrimExpr VisitExpr_ (const SelectNode* op) final {
297+ PrimExpr condition = this ->VisitExpr (op->condition );
298+ PrimExpr true_value = this ->VisitExpr (op->true_value );
299+ PrimExpr false_value = this ->VisitExpr (op->false_value );
300+ if (condition.same_as (op->condition ) && true_value.same_as (op->true_value ) &&
301+ false_value.same_as (op->false_value )) {
302+ return GetRef<PrimExpr>(op);
303+ } else {
304+ if (op->true_value .dtype ().is_int () && op->false_value .dtype ().is_int ()) {
305+ int bits = std::max (true_value.dtype ().bits (), false_value.dtype ().bits ());
306+ DataType dtype = true_value.dtype ().with_bits (bits);
307+ if (true_value.dtype () != dtype) true_value = cast (dtype, true_value);
308+ if (false_value.dtype () != dtype) false_value = cast (dtype, false_value);
309+ }
310+ return Select (condition, true_value, false_value);
311+ }
312+ }
313+
296314 PrimExpr VisitExpr_ (const RampNode* op) final {
297315 PrimExpr base = VisitExpr (op->base );
298316 PrimExpr stride = VisitExpr (op->stride );
You can’t perform that action at this time.
0 commit comments