@@ -55,6 +55,7 @@ class BF16PromoteRewriter : public StmtExprMutator {
5555 PrimExpr VisitExpr_ (const LENode* op) final ;
5656 PrimExpr VisitExpr_ (const GTNode* op) final ;
5757 PrimExpr VisitExpr_ (const GENode* op) final ;
58+ PrimExpr VisitExpr_ (const CallNode* op) final ;
5859};
5960
6061#define DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH (OP, FUNC, NEEDCAST ) \
@@ -88,6 +89,26 @@ DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(LENode, operator<=, false)
8889DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH (GTNode, operator >, false )
8990DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH (GENode, operator >=, false )
9091
92+ PrimExpr BF16PromoteRewriter::VisitExpr_ (const CallNode* op) {
93+ Array<PrimExpr> args;
94+ for (auto & arg : op->args ) {
95+ PrimExpr x = this ->VisitExpr (arg);
96+ if (x.dtype ().is_bfloat16 ()) {
97+ DataType fp32_dtype (kDLFloat , 32 , x.dtype ().lanes ());
98+ args.push_back (Cast (fp32_dtype, {x}, op->span ));
99+ } else {
100+ args.push_back (x);
101+ }
102+ }
103+ if (op->dtype .is_bfloat16 ()) {
104+ DataType fp32_dtype (kDLFloat , 32 , op->dtype .lanes ());
105+ PrimExpr result_fp32 = Call (fp32_dtype, op->op , args, op->span );
106+ return Cast (op->dtype , {result_fp32}, op->span );
107+ } else {
108+ return Call (op->dtype , op->op , args, op->span );
109+ }
110+ }
111+
91112/*
92113 * Eliminate verbose casting between fp32 and bf16
93114 * Checks if the AST has the pattern:
0 commit comments