Skip to content

Commit efd7c45

Browse files
authored
[TIR] [bfloat16] add bfloat16 promotion for CallNode (#12370)
* Add bfloat16 promotion for CallNode
* Add softmax to the bfloat16 build test
1 parent 88928a4 commit efd7c45

File tree

2 files changed

+24
-1
lines changed

2 files changed

+24
-1
lines changed

src/tir/transforms/bf16_legalize.cc

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ class BF16PromoteRewriter : public StmtExprMutator {
5555
PrimExpr VisitExpr_(const LENode* op) final;
5656
PrimExpr VisitExpr_(const GTNode* op) final;
5757
PrimExpr VisitExpr_(const GENode* op) final;
58+
PrimExpr VisitExpr_(const CallNode* op) final;
5859
};
5960

6061
#define DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(OP, FUNC, NEEDCAST) \
@@ -88,6 +89,26 @@ DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(LENode, operator<=, false)
8889
DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(GTNode, operator>, false)
8990
DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(GENode, operator>=, false)
9091

92+
// Rewrite a CallNode so it never computes in bfloat16: every bf16 argument
// is widened to fp32 (same lane count), and if the call itself produces
// bf16, the call is evaluated in fp32 and its result cast back to bf16 so
// the expression's observable dtype is unchanged for callers.
PrimExpr BF16PromoteRewriter::VisitExpr_(const CallNode* op) {
  Array<PrimExpr> args;
  args.reserve(op->args.size());  // one allocation instead of incremental growth
  for (const auto& arg : op->args) {
    // Recursively rewrite the argument first; cast only if it is still bf16.
    PrimExpr x = this->VisitExpr(arg);
    if (x.dtype().is_bfloat16()) {
      DataType fp32_dtype(kDLFloat, 32, x.dtype().lanes());
      args.push_back(Cast(fp32_dtype, {x}, op->span));
    } else {
      args.push_back(x);
    }
  }
  if (op->dtype.is_bfloat16()) {
    // Compute the call in fp32, then narrow back to the original bf16 dtype.
    DataType fp32_dtype(kDLFloat, 32, op->dtype.lanes());
    PrimExpr result_fp32 = Call(fp32_dtype, op->op, args, op->span);
    return Cast(op->dtype, {result_fp32}, op->span);
  }
  return Call(op->dtype, op->op, args, op->span);
}
111+
91112
/*
92113
* Eliminate verbose casting between fp32 and bf16
93114
* Checks if the AST has the pattern:

tests/python/relay/test_cpp_build_module.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,9 @@ def test_bf16_build():
127127
relu_bf16 = relay.nn.relu(bn_bf16[0])
128128
maxpool_bf16 = relay.nn.max_pool2d(relu_bf16, pool_size=(2, 2), strides=(2, 2))
129129
avgpool_bf16 = relay.nn.avg_pool2d(maxpool_bf16, pool_size=(2, 2), strides=(2, 2))
130-
mod_bf16 = tvm.IRModule.from_expr(avgpool_bf16)
130+
flattened_bf16 = relay.nn.batch_flatten(avgpool_bf16)
131+
softmax_bf16 = relay.nn.softmax(flattened_bf16)
132+
mod_bf16 = tvm.IRModule.from_expr(softmax_bf16)
131133
with tvm.transform.PassContext(opt_level=3):
132134
relay.build(mod_bf16, target="llvm", params=params)
133135

0 commit comments

Comments
 (0)