From 539e5e09b30bf45b525a5a96f5b784719b349832 Mon Sep 17 00:00:00 2001
From: Youlei Yang
Date: Mon, 8 Nov 2021 17:00:28 +0800
Subject: [PATCH 01/16] update AMP table to enable ResNet50 conversion

---
 python/tvm/relay/transform/mixed_precision.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/python/tvm/relay/transform/mixed_precision.py b/python/tvm/relay/transform/mixed_precision.py
index 5018ba9ba9a7..9bf6d7abd868 100644
--- a/python/tvm/relay/transform/mixed_precision.py
+++ b/python/tvm/relay/transform/mixed_precision.py
@@ -40,6 +40,8 @@
     "nn.conv3d_transpose",
     "nn.dense",
     "nn.batch_matmul",
+    "nn.bias_add",
+    "nn.batch_norm",
 ]
 DEFAULT_FOLLOW_LIST = [
     # These ops add new data or change shape
@@ -80,8 +82,6 @@
     "subtract",
     "multiply",
     "divide",
-    "nn.bias_add",
-    "nn.batch_norm",
     "sqrt",
     "shape_of",
     # Simple activations

From 9fd4dd906866d95571de1745e5f4ce005f11ad6b Mon Sep 17 00:00:00 2001
From: Youlei Yang
Date: Mon, 8 Nov 2021 17:19:39 +0800
Subject: [PATCH 02/16] add runtime datatype dispatch for BFloat16

---
 src/relay/transforms/pattern_utils.h | 20 ++++++++++++++++++++
 1 file changed, 20 insertions(+)

diff --git a/src/relay/transforms/pattern_utils.h b/src/relay/transforms/pattern_utils.h
index 69ad20a7ceaf..d55ebdc20bc5 100644
--- a/src/relay/transforms/pattern_utils.h
+++ b/src/relay/transforms/pattern_utils.h
@@ -63,6 +63,9 @@ namespace relay {
   } else if (type == DataType::Float(16)) {  \
     typedef uint16_t DType;                  \
     { __VA_ARGS__ }                          \
+  } else if (type == DataType::BFloat(16)) { \
+    typedef uint16_t DType;                  \
+    { __VA_ARGS__ }                          \
   } else if (type == DataType::Int(64)) {    \
     typedef int64_t DType;                   \
     { __VA_ARGS__ }                          \
@@ -259,6 +262,11 @@ inline Constant MakeConstantScalar(DataType dtype, T value) {
       // storage is uint16_t
       *static_cast<DType*>(arr->data) =
           __truncXfYf2__<float, uint32_t, 23, uint16_t, uint16_t, 10>(static_cast<float>(value));
+    } else if (dtype == DataType::BFloat(16)) {
+      // convert to bfloat16
+      // storage is uint16_t
+      *static_cast<DType*>(arr->data) =
+          __truncXfYf2__<float, uint32_t, 23, uint16_t, uint16_t, 7>(static_cast<float>(value));
     } else {
       *static_cast<DType*>(arr->data) = value;
     }
@@ -286,6 +294,12 @@ static inline Constant MakeConstantTensor(DataType dtype, std::vector<int64_t> s
         *(static_cast<DType*>(arr->data) + i) =
             __truncXfYf2__<float, uint32_t, 23, uint16_t, uint16_t, 10>(
                 static_cast<float>(value[i]));
+      } else if (dtype == DataType::BFloat(16)) {
+        // convert to bfloat16
+        // storage is uint16_t
+        *(static_cast<DType*>(arr->data) + i) =
+            __truncXfYf2__<float, uint32_t, 23, uint16_t, uint16_t, 7>(
+                static_cast<float>(value[i]));
       } else {
         *(static_cast<DType*>(arr->data) + i) = value[i];
       }
@@ -314,6 +328,12 @@ static inline Constant MakeConstantTensor(DataType dtype, std::vector<int64_t> s
         *(static_cast<DType*>(arr->data) + i) =
             __truncXfYf2__<float, uint32_t, 23, uint16_t, uint16_t, 10>(
                 static_cast<float>(value[i]));
+      } else if (dtype == DataType::BFloat(16)) {
+        // convert to bfloat16
+        // storage is uint16_t
+        *(static_cast<DType*>(arr->data) + i) =
+            __truncXfYf2__<float, uint32_t, 23, uint16_t, uint16_t, 7>(
+                static_cast<float>(value[i]));
       } else {
         *(static_cast<DType*>(arr->data) + i) = value[i];
       }

From 8bf6144af78e3bfeca6b7cc3818510af9f56fa2b Mon Sep 17 00:00:00 2001
From: Youlei Yang
Date: Mon, 15 Nov 2021 08:26:01 +0800
Subject: [PATCH 03/16] skip asserts for uint16 for bf16 compatibility

---
 src/tir/transforms/arg_binder.cc | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/tir/transforms/arg_binder.cc b/src/tir/transforms/arg_binder.cc
index d3ab32cbd7f9..f4c2a4d4b1e0 100644
--- a/src/tir/transforms/arg_binder.cc
+++ b/src/tir/transforms/arg_binder.cc
@@ -169,7 +169,8 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type,
                    IntImm(DataType::UInt(8), dtype.bits()) &&
                TVMArrayGet(DataType::UInt(16), handle, builtin::kArrTypeLanes) ==
                    IntImm(DataType::UInt(16),
dtype.lanes())); - if (!(dtype == DataType::Int(4) || dtype == DataType::UInt(4) || dtype == DataType::Int(1))) { + if (!(dtype == DataType::Int(4) || dtype == DataType::UInt(4) || + dtype == DataType::Int(1) || dtype == DataType::UInt(16))) { auto type_msg = tvm::tir::StringImm(type_err_msg.str()); asserts_.emplace_back(AssertStmt(a_ndim == v_ndim, msg, nop)); asserts_.emplace_back(AssertStmt(cond, type_msg, nop)); From a786001527bbf7c047dcbd3a6d924abc12a868a5 Mon Sep 17 00:00:00 2001 From: Youlei Yang Date: Tue, 16 Nov 2021 15:13:55 +0800 Subject: [PATCH 04/16] add bf16 cast for the unary intrinsic operators --- include/tvm/tir/op.h | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h index 9cf7d0a3cd1f..2de3056f0bd1 100644 --- a/include/tvm/tir/op.h +++ b/include/tvm/tir/op.h @@ -835,10 +835,18 @@ TVM_DLL PrimExpr q_multiply_shift(PrimExpr x, PrimExpr y, PrimExpr q, PrimExpr s Span span = Span()); // Intrinsic operators -#define TVM_DECLARE_INTRIN_UNARY(OpName) \ - inline PrimExpr OpName(PrimExpr x, Span span = Span()) { \ - static const Op& op = Op::Get("tir." #OpName); \ - return tir::Call(x.dtype(), op, {x}, span); \ +#define TVM_DECLARE_INTRIN_UNARY(OpName) \ + inline PrimExpr OpName(PrimExpr x, Span span = Span()) { \ + static const Op& op = Op::Get("tir." #OpName); \ + if (x.dtype().is_bfloat16()) { \ + DataType srcType = x.dtype(); \ + DataType dstType(kDLFloat, 32, srcType.lanes()); \ + PrimExpr castX = tir::Cast(dstType, {x}, span); \ + PrimExpr result = tir::Call(dstType, op, {castX}, span); \ + return tir::Cast(srcType, {result}, span); \ + } else { \ + return tir::Call(x.dtype(), op, {x}, span); \ + } \ } TVM_DECLARE_INTRIN_UNARY(exp); From c959702c34d158be91503f9eb27930b64ce2e4bc Mon Sep 17 00:00:00 2001 From: Youlei Yang Date: Thu, 18 Nov 2021 10:15:45 +0800 Subject: [PATCH 05/16] enable "bf16<-->fp32<-->any dtype" casting --- src/tir/transforms/bf16_legalize.cc | 30 ++++++++++++++--------------- 1 file changed, 14 insertions(+), 16 deletions(-) diff --git a/src/tir/transforms/bf16_legalize.cc b/src/tir/transforms/bf16_legalize.cc index 76845cbebd2a..b2213e1fefdf 100644 --- a/src/tir/transforms/bf16_legalize.cc +++ b/src/tir/transforms/bf16_legalize.cc @@ -179,25 +179,23 @@ class BF16LowerRewriter : public StmtExprMutator { using StmtExprMutator::operator(); PrimExpr VisitExpr_(const CastNode* op) final { - auto op_val = StmtExprMutator::VisitExpr(op->value); - if (op->value->dtype.is_bfloat16()) { - // if is cast_from_bf16, check if is to fp32 - ICHECK(op->dtype.is_float() && op->dtype.bits() == 32); - auto uint32_dtype = DataType(kDLUInt, 32, op_val->dtype.lanes()); - auto uint32_v = Cast(uint32_dtype, op_val); - // to be endian invariant. 
- return Call(op->dtype, builtin::reinterpret(), {uint32_v << 16}); - } else if (op->dtype.is_bfloat16()) { - // if is cast_to_bf16, check if op->value is fp32 - ICHECK(op->value->dtype.is_float() && op->value->dtype.bits() == 32); - auto uint32_dtype = DataType(kDLUInt, 32, op_val->dtype.lanes()); - auto uint32_v = Call(uint32_dtype, builtin::reinterpret(), {op_val}); - auto uint16_dtype = DataType(kDLUInt, 16, op_val->dtype.lanes()); + PrimExpr op_val = StmtExprMutator::VisitExpr(op->value); + DataType uint32_dtype(kDLUInt, 32, op_val->dtype.lanes()); + DataType float32_dtype(kDLFloat, 32, op_val->dtype.lanes()); + if (op->value->dtype.is_bfloat16()) { // cast from bf16 + PrimExpr uint32_v = Cast(uint32_dtype, op_val); + PrimExpr float32_v = Call(float32_dtype, builtin::reinterpret(), {uint32_v << 16}); + bool is_to_float32 = op->dtype.is_float() && op->dtype.bits() == 32; + return is_to_float32 ? float32_v : Cast(op->dtype, float32_v); + } else if (op->dtype.is_bfloat16()) { // cast to bf16 + bool is_from_float32 = op->value->dtype.is_float() && op->value->dtype.bits() == 32; + PrimExpr float32_v = is_from_float32 ? op_val : Cast(float32_dtype, op_val); + PrimExpr uint32_v = Call(uint32_dtype, builtin::reinterpret(), {float32_v}); + DataType uint16_dtype(kDLUInt, 16, op_val->dtype.lanes()); /* the following TIR is equivalent to the C++ code below: uint32_t rounding_bias = ((U32 >> 16) & 1) + UINT32_C(0x7FFF); return static_cast((U32 + rounding_bias) >> 16);*/ - auto rounding_bias = ((uint32_v >> 16) & 1) + make_const(uint16_dtype, 0x7FFF); - // to be endian invariant. + PrimExpr rounding_bias = ((uint32_v >> 16) & 1) + make_const(uint16_dtype, 0x7FFF); return Cast(uint16_dtype, {(uint32_v + rounding_bias) >> 16}); } if (op->value.same_as(op_val)) return GetRef(op); From 7e77f567794b49a721cc91ec5b8b9d7f2702ee90 Mon Sep 17 00:00:00 2001 From: Youlei Yang Date: Thu, 18 Nov 2021 13:44:27 +0800 Subject: [PATCH 06/16] support inconsistent input for bf16 BIOP legalize --- src/tir/transforms/bf16_legalize.cc | 85 ++++++++++------------------- 1 file changed, 28 insertions(+), 57 deletions(-) diff --git a/src/tir/transforms/bf16_legalize.cc b/src/tir/transforms/bf16_legalize.cc index b2213e1fefdf..79c406818185 100644 --- a/src/tir/transforms/bf16_legalize.cc +++ b/src/tir/transforms/bf16_legalize.cc @@ -45,26 +45,6 @@ class BF16PromoteRewriter : public StmtExprMutator { Stmt operator()(Stmt s) { return VisitStmt(s); } - std::tuple DoCast(PrimExpr orig_a, PrimExpr orig_b, bool* is_bfloat16) { - auto a = this->VisitExpr(orig_a); - auto b = this->VisitExpr(orig_b); - *is_bfloat16 = false; - if (a->dtype.is_bfloat16()) { - ICHECK(b->dtype.is_bfloat16()); - *is_bfloat16 = true; - } else if (b->dtype.is_bfloat16()) { - ICHECK(a->dtype.is_bfloat16()); - *is_bfloat16 = true; - } - - if (*is_bfloat16) { - DataType fp32ty(kDLFloat, 32, 1); - a = Cast(fp32ty, a); - b = Cast(fp32ty, b); - } - return std::make_tuple(a, b); - } - PrimExpr VisitExpr_(const AddNode* op) final; PrimExpr VisitExpr_(const SubNode* op) final; PrimExpr VisitExpr_(const MulNode* op) final; @@ -77,45 +57,36 @@ class BF16PromoteRewriter : public StmtExprMutator { PrimExpr VisitExpr_(const GENode* op) final; }; -#define DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(OP, FUNC) \ - PrimExpr BF16PromoteRewriter::VisitExpr_(const OP* op) { \ - PrimExpr a, b; \ - bool is_bfloat16; \ - std::tie(a, b) = DoCast(op->a, op->b, &is_bfloat16); \ - if (a.same_as(op->a) && b.same_as(op->b)) { \ - return GetRef(op); \ - } else { \ - auto ret = FUNC(a, b); \ 
- if (!is_bfloat16) \ - return ret; \ - else \ - return Cast(DataType(kDLBfloat, 16, 1), ret); \ - } \ - } - -#define DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH_NO_CAST(OP, FUNC) \ - PrimExpr BF16PromoteRewriter::VisitExpr_(const OP* op) { \ - PrimExpr a, b; \ - bool is_bfloat16; \ - std::tie(a, b) = DoCast(op->a, op->b, &is_bfloat16); \ - if (a.same_as(op->a) && b.same_as(op->b)) { \ - return GetRef(op); \ - } else { \ - auto ret = FUNC(a, b); \ - return ret; \ - } \ +#define DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(OP, FUNC, NEEDCAST) \ + PrimExpr BF16PromoteRewriter::VisitExpr_(const OP* op) { \ + PrimExpr origin_a = this->VisitExpr(op->a); \ + PrimExpr origin_b = this->VisitExpr(op->b); \ + bool a_is_bfloat16 = origin_a->dtype.is_bfloat16(); \ + bool b_is_bfloat16 = origin_b->dtype.is_bfloat16(); \ + bool both_bfloat16 = a_is_bfloat16 && b_is_bfloat16; \ + bool none_bfloat16 = !(a_is_bfloat16 || b_is_bfloat16); \ + if (none_bfloat16) { \ + return GetRef(op); \ + } \ + DataType float32_dtype(kDLFloat, 32, 1); \ + PrimExpr float32_a = a_is_bfloat16 ? Cast(float32_dtype, origin_a) : origin_a; \ + PrimExpr float32_b = b_is_bfloat16 ? Cast(float32_dtype, origin_b) : origin_b; \ + PrimExpr result = FUNC(float32_a, float32_b); \ + DataType bfloat16_dtype(kDLBfloat, 16, 1); \ + bool do_cast = both_bfloat16 && NEEDCAST; \ + return do_cast ? Cast(bfloat16_dtype, result) : result; \ } -DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(AddNode, operator+) -DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(SubNode, operator-) -DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MulNode, operator*) -DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(DivNode, div) -DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MinNode, min) -DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MaxNode, max) -DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH_NO_CAST(LTNode, operator<) // NOLINT(*) -DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH_NO_CAST(LENode, operator<=) // NOLINT(*) -DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH_NO_CAST(GTNode, operator>) // NOLINT(*) -DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH_NO_CAST(GENode, operator>=) // NOLINT(*) +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(AddNode, operator+, true) +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(SubNode, operator-, true) +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MulNode, operator*, true) +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(DivNode, div, true) +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MinNode, min, true) +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MaxNode, max, true) +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(LTNode, operator<, false) +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(LENode, operator<=, false) +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(GTNode, operator>, false) +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(GENode, operator>=, false) /* * Eliminate verbose casting between fp32 and bf16 From 8e0766ce6fcfae44b4fe6998c8ec18252b8b2ab0 Mon Sep 17 00:00:00 2001 From: Youlei Yang Date: Mon, 22 Nov 2021 10:07:41 +0800 Subject: [PATCH 07/16] add treatments for bfloat16 in if statements --- .gitignore | 3 +++ src/arith/rewrite_simplify.cc | 2 +- src/auto_scheduler/feature.cc | 20 +++++++++---------- src/autotvm/touch_extractor.h | 20 ++++++++++++++----- src/contrib/hybrid/codegen_hybrid.cc | 3 +++ .../backend/contrib/codegen_c/codegen_c.h | 2 ++ src/relay/backend/utils.h | 2 ++ src/relay/op/nn/nn.cc | 3 ++- src/relay/transforms/pattern_utils.h | 6 ++++++ src/runtime/crt/common/packed_func.c | 3 +++ src/runtime/vm/bytecode.cc | 3 +++ src/tir/op/op.cc | 4 ++++ src/tir/transforms/lower_intrin.cc | 6 +++--- src/tir/transforms/make_packed_api.cc | 2 +- 14 files 
changed, 58 insertions(+), 21 deletions(-) diff --git a/.gitignore b/.gitignore index b2b6afb21544..420aacf0e343 100644 --- a/.gitignore +++ b/.gitignore @@ -11,7 +11,10 @@ __pycache__/ .Python env/ build/ +build_debug/ +build_release/ develop-eggs/ +dev_tvm/ dist/ downloads/ eggs/ diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index 4a99e10211b7..55f0cf5f3929 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -461,7 +461,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const DivNode* op) { // x / 2.0 = x * 0.5 if (const FloatImmNode* ptr = op->b.as()) { - ICHECK(op->dtype.is_float() || + ICHECK(op->dtype.is_float() || op->dtype.is_bfloat16() || datatype::Registry::Global()->GetTypeRegistered(op->dtype.code())); return op->a * make_const(op->b.dtype(), 1.0 / ptr->value); } diff --git a/src/auto_scheduler/feature.cc b/src/auto_scheduler/feature.cc index aaf7d48b10c5..5809888543c6 100755 --- a/src/auto_scheduler/feature.cc +++ b/src/auto_scheduler/feature.cc @@ -246,14 +246,14 @@ int64_t GetLoopExtent(const ForNode* node) { // Count math ops in an expr class MathOpCounter : public StmtExprVisitor { public: -#define VisitBinary(Type, float_ct, int_ct) \ - void VisitExpr_(const Type* op) final { \ - if (op->a.dtype().is_float()) { \ - float_ct++; \ - } else { \ - int_ct++; \ - } \ - StmtExprVisitor::VisitExpr_(op); \ +#define VisitBinary(Type, float_ct, int_ct) \ + void VisitExpr_(const Type* op) final { \ + if (op->a.dtype().is_float() || op->a.dtype().is_bfloat16()) { \ + float_ct++; \ + } else { \ + int_ct++; \ + } \ + StmtExprVisitor::VisitExpr_(op); \ } VisitBinary(AddNode, float_addsub, int_addsub); @@ -299,13 +299,13 @@ class MathOpCounter : public StmtExprVisitor { effect_kind == CallEffectKind::kPure || effect_kind == CallEffectKind::kExprAnnotation; if (is_pure) { - if (op->dtype.is_float()) { + if (op->dtype.is_float() || op->dtype.is_bfloat16()) { float_math_func++; } else { int_math_func++; } } else { - if (op->dtype.is_float()) { + if (op->dtype.is_float() || op->dtype.is_bfloat16()) { float_other_func++; } else { int_other_func++; diff --git a/src/autotvm/touch_extractor.h b/src/autotvm/touch_extractor.h index 313e4d78d6e1..83260e1e0633 100644 --- a/src/autotvm/touch_extractor.h +++ b/src/autotvm/touch_extractor.h @@ -87,27 +87,37 @@ class TouchExtractor : public FeatureVisitor { // arithmetic stats void VisitExpr_(const AddNode* op) final { - if (op->dtype.is_float()) itervar_map[itervar_stack_.back()].add_ct++; + if (op->dtype.is_float() || op->dtype.is_bfloat16()) { + itervar_map[itervar_stack_.back()].add_ct++; + } FeatureVisitor::VisitExpr_(op); } void VisitExpr_(const SubNode* op) final { - if (op->dtype.is_float()) itervar_map[itervar_stack_.back()].add_ct++; + if (op->dtype.is_float() || op->dtype.is_bfloat16()) { + itervar_map[itervar_stack_.back()].add_ct++; + } FeatureVisitor::VisitExpr_(op); } void VisitExpr_(const MulNode* op) final { - if (op->dtype.is_float()) itervar_map[itervar_stack_.back()].mul_ct++; + if (op->dtype.is_float() || op->dtype.is_bfloat16()) { + itervar_map[itervar_stack_.back()].mul_ct++; + } FeatureVisitor::VisitExpr_(op); } void VisitExpr_(const DivNode* op) final { - if (op->dtype.is_float()) itervar_map[itervar_stack_.back()].div_ct++; + if (op->dtype.is_float() || op->dtype.is_bfloat16()) { + itervar_map[itervar_stack_.back()].div_ct++; + } FeatureVisitor::VisitExpr_(op); } void VisitExpr_(const ModNode* op) final { - if (op->dtype.is_float()) 
itervar_map[itervar_stack_.back()].div_ct++; + if (op->dtype.is_float() || op->dtype.is_bfloat16()) { + itervar_map[itervar_stack_.back()].div_ct++; + } FeatureVisitor::VisitExpr_(op); } diff --git a/src/contrib/hybrid/codegen_hybrid.cc b/src/contrib/hybrid/codegen_hybrid.cc index 54edbaee35cd..5872a49968cb 100644 --- a/src/contrib/hybrid/codegen_hybrid.cc +++ b/src/contrib/hybrid/codegen_hybrid.cc @@ -69,6 +69,9 @@ void CodeGenHybrid::PrintType(DataType t, std::ostream& os) { } else if (t.is_int()) { os << "int"; ICHECK(t.bits() == 8 || t.bits() == 16 || t.bits() == 32 || t.bits() == 64); + } else if (t.is_bfloat16()) { + os << "bfloat"; + ICHECK(t.bits() == 16); } else { ICHECK(t.is_uint()) << "Unsupported type " << t; os << "uint"; diff --git a/src/relay/backend/contrib/codegen_c/codegen_c.h b/src/relay/backend/contrib/codegen_c/codegen_c.h index 3c6f810534ea..49a5bca068d1 100644 --- a/src/relay/backend/contrib/codegen_c/codegen_c.h +++ b/src/relay/backend/contrib/codegen_c/codegen_c.h @@ -363,6 +363,8 @@ class CodegenCBase { dtype = "float"; } else if (runtime::TypeMatch(ttype->dtype, kDLFloat, 16)) { dtype = "half"; + } else if (runtime::TypeMatch(ttype->dtype, kDLBfloat, 16)) { + dtype = "bfloat"; } else if (runtime::TypeMatch(ttype->dtype, kDLInt, 32)) { dtype = "int"; } else if (runtime::TypeMatch(ttype->dtype, kDLInt, 64)) { diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h index 658283b5dc36..02075616c6c8 100644 --- a/src/relay/backend/utils.h +++ b/src/relay/backend/utils.h @@ -302,6 +302,8 @@ inline std::string DType2String(const tvm::DataType dtype) { os << "int"; } else if (dtype.is_uint()) { os << "uint"; + } else if (dtype.is_bfloat16()) { + os << "bfloat"; } else if ((*GetPackedFunc("runtime._datatype_get_type_registered"))(dtype.code())) { os << "custom[" << (*GetPackedFunc("runtime._datatype_get_type_name"))(dtype.code()).operator std::string() diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc index 89ef2708ff27..a959dd7e9915 100644 --- a/src/relay/op/nn/nn.cc +++ b/src/relay/op/nn/nn.cc @@ -1177,7 +1177,8 @@ bool NLLLossRel(const Array& types, int num_inputs, const Attrs& attrs, << ", weights shape = " << weights->shape); return false; } - if (!(predictions->dtype == weights->dtype && predictions->dtype.is_float())) { + if (!(predictions->dtype == weights->dtype && + (predictions->dtype.is_float() || predictions->dtype.is_bfloat16()))) { reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan()) << "NLLLossRel: predictions and weights should" << " be of the same floating type."); diff --git a/src/relay/transforms/pattern_utils.h b/src/relay/transforms/pattern_utils.h index d55ebdc20bc5..4084553419df 100644 --- a/src/relay/transforms/pattern_utils.h +++ b/src/relay/transforms/pattern_utils.h @@ -437,6 +437,12 @@ static inline dmlc::optional TryToScalar(const runtime::NDArray& ar } else if (array->dtype.bits == 64) { return dmlc::optional(reinterpret_cast(array->data)[i]); } + } else if (array->dtype.code == kDLBfloat) { + if (array->dtype.bits == 16) { + return dmlc::optional( + __extendXfYf2__( + reinterpret_cast(array->data)[i])); + } } return dmlc::optional(); } diff --git a/src/runtime/crt/common/packed_func.c b/src/runtime/crt/common/packed_func.c index e946cda9d9ae..645b22f3b255 100644 --- a/src/runtime/crt/common/packed_func.c +++ b/src/runtime/crt/common/packed_func.c @@ -49,6 +49,9 @@ DLDataType String2DLDataType(const char* s) { } else if (!strncmp(s, "float", 5)) { t.code = kDLFloat; scan = s + 5; + } else if (!strncmp(s, 
"bfloat", 6)) { + t.code = kDLBfloat; + scan = s + 6; } else if (!strncmp(s, "handle", 6)) { t.code = kTVMOpaqueHandle; t.bits = 64; // handle uses 64 bit by default. diff --git a/src/runtime/vm/bytecode.cc b/src/runtime/vm/bytecode.cc index f83e27d2c11d..a2fa478ac6c8 100644 --- a/src/runtime/vm/bytecode.cc +++ b/src/runtime/vm/bytecode.cc @@ -497,6 +497,9 @@ void DLDatatypePrint(std::ostream& os, const DLDataType& dtype) { case kDLFloat: os << "float"; break; + case kDLBfloat: + os << "bfloat"; + break; } os << int(dtype.bits); diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc index d08bef2ab91a..577413777a6f 100644 --- a/src/tir/op/op.cc +++ b/src/tir/op/op.cc @@ -186,6 +186,8 @@ PrimExpr max_value(const DataType& dtype, Span span) { } else if (dtype.bits() == 16) { return FloatImm(dtype, 65504.0, span); } + } else if (dtype.is_bfloat16()) { + return FloatImm(dtype, std::numeric_limits::max(), span); } LOG(FATAL) << "Cannot decide max_value for type" << dtype; return PrimExpr(); @@ -219,6 +221,8 @@ PrimExpr min_value(const DataType& dtype, Span span) { } else if (dtype.bits() == 16) { return FloatImm(dtype, -65504.0, span); } + } else if (dtype.is_bfloat16()) { + return FloatImm(dtype, std::numeric_limits::lowest(), span); } LOG(FATAL) << "Cannot decide min_value for type" << dtype; return PrimExpr(); diff --git a/src/tir/transforms/lower_intrin.cc b/src/tir/transforms/lower_intrin.cc index 2555002d29b0..76f9fdff40f4 100644 --- a/src/tir/transforms/lower_intrin.cc +++ b/src/tir/transforms/lower_intrin.cc @@ -127,7 +127,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { } } } else { - if (dtype.is_float()) { + if (dtype.is_float() || dtype.is_bfloat16()) { // floor(a / b) return VisitExpr_(tvm::floor(op->a / op->b).as()); } else { @@ -181,7 +181,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { } } } else { - if (dtype.is_float()) { + if (dtype.is_float() || dtype.is_bfloat16()) { // a - floor(a / b) * b return op->a - (VisitExpr_(tvm::floor(op->a / op->b).as()) * op->b); } else { @@ -269,7 +269,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { PrimExpr lhs = SwapBroadcastCast(a); PrimExpr rhs = SwapBroadcastCast(b); - if (fma_ != nullptr && op->dtype.is_float()) { + if (fma_ != nullptr && (op->dtype.is_float() || op->dtype.is_bfloat16())) { PrimExpr r = fma_(Call(op->dtype, builtin::fma(), {lhs, rhs, c})); if (r.defined()) return this->VisitExpr(r); } else { diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc index d7e1beff03d3..136fc6008d6f 100644 --- a/src/tir/transforms/make_packed_api.cc +++ b/src/tir/transforms/make_packed_api.cc @@ -73,7 +73,7 @@ class ReturnRewriter : public StmtMutator { DataType dtype = val.dtype(); if (dtype.is_int() || dtype.is_uint()) { return {kTVMArgInt, Cast(DataType::Int(64), val)}; - } else if (dtype.is_float()) { + } else if (dtype.is_float() || dtype.is_bfloat16()) { return {kTVMArgFloat, Cast(DataType::Float(64), val)}; } else if (dtype.is_void()) { return {kTVMNullptr, val}; From 12fb7b19470a462c74f084b42e9a02e2dec0e067 Mon Sep 17 00:00:00 2001 From: Youlei Yang Date: Wed, 15 Dec 2021 15:01:34 +0800 Subject: [PATCH 08/16] add bfloat16 dtype casts in binary OP --- src/tir/op/op.cc | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc index 577413777a6f..330db9a6d17f 100644 --- a/src/tir/op/op.cc +++ b/src/tir/op/op.cc @@ -128,6 +128,14 @@ void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs, Span span) { // 
NOLINT(*) !rtype.is_float()) { // Cast int->float when the other operand is a float rhs = cast(ltype, rhs); + } else if (!ltype.is_bfloat16() && + (rtype.is_bfloat16() || datatype::Registry::Global()->GetTypeRegistered(rtype.code()))) { + // Cast int->float when the other operand is a float + lhs = cast(rtype, lhs); + } else if ((ltype.is_bfloat16() || datatype::Registry::Global()->GetTypeRegistered(ltype.code())) && + !rtype.is_bfloat16()) { + // Cast int->float when the other operand is a float + rhs = cast(ltype, rhs); } else if ((ltype.is_int() && rtype.is_int()) || (ltype.is_uint() && rtype.is_uint())) { // Promote int to higher bits e.g. int8 + int16 --> int16 + int16 if (ltype.bits() < rtype.bits()) { From 9e32fde441b3c325e874e342ef4084ef3a955f05 Mon Sep 17 00:00:00 2001 From: Youlei Yang Date: Wed, 15 Dec 2021 15:03:01 +0800 Subject: [PATCH 09/16] delete unnecessary treatments for bfloat16 --- src/tir/transforms/lower_intrin.cc | 6 +++--- src/tir/transforms/make_packed_api.cc | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/tir/transforms/lower_intrin.cc b/src/tir/transforms/lower_intrin.cc index 76f9fdff40f4..f22417b78a70 100644 --- a/src/tir/transforms/lower_intrin.cc +++ b/src/tir/transforms/lower_intrin.cc @@ -127,7 +127,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { } } } else { - if (dtype.is_float() || dtype.is_bfloat16()) { + if (dtype.is_float()) { // floor(a / b) return VisitExpr_(tvm::floor(op->a / op->b).as()); } else { @@ -181,7 +181,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { } } } else { - if (dtype.is_float() || dtype.is_bfloat16()) { + if (dtype.is_float()) { // a - floor(a / b) * b return op->a - (VisitExpr_(tvm::floor(op->a / op->b).as()) * op->b); } else { @@ -269,7 +269,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { PrimExpr lhs = SwapBroadcastCast(a); PrimExpr rhs = SwapBroadcastCast(b); - if (fma_ != nullptr && (op->dtype.is_float() || op->dtype.is_bfloat16())) { + if (fma_ != nullptr && (op->dtype.is_float())) { PrimExpr r = fma_(Call(op->dtype, builtin::fma(), {lhs, rhs, c})); if (r.defined()) return this->VisitExpr(r); } else { diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc index 136fc6008d6f..d7e1beff03d3 100644 --- a/src/tir/transforms/make_packed_api.cc +++ b/src/tir/transforms/make_packed_api.cc @@ -73,7 +73,7 @@ class ReturnRewriter : public StmtMutator { DataType dtype = val.dtype(); if (dtype.is_int() || dtype.is_uint()) { return {kTVMArgInt, Cast(DataType::Int(64), val)}; - } else if (dtype.is_float() || dtype.is_bfloat16()) { + } else if (dtype.is_float()) { return {kTVMArgFloat, Cast(DataType::Float(64), val)}; } else if (dtype.is_void()) { return {kTVMNullptr, val}; From 6c29073508ecc8d31a69338eaff5175256b9c695 Mon Sep 17 00:00:00 2001 From: yangulei Date: Thu, 20 Jan 2022 15:24:08 +0800 Subject: [PATCH 10/16] add test for bfloat16 building --- tests/python/relay/test_cpp_build_module.py | 30 +++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/tests/python/relay/test_cpp_build_module.py b/tests/python/relay/test_cpp_build_module.py index 23bc7ca95a34..ab9fa7aca442 100644 --- a/tests/python/relay/test_cpp_build_module.py +++ b/tests/python/relay/test_cpp_build_module.py @@ -93,6 +93,35 @@ def test_fp16_build(): np.testing.assert_allclose(out.numpy(), X.numpy() + Y.numpy(), atol=1e-5, rtol=1e-5) +@tvm.testing.requires_llvm +def test_bf16_build(): + data = relay.var("data", shape=(1, 3, 224, 
224), dtype='float32') + weight = relay.var("weight", shape=(64, 3, 7, 7), dtype='float32') + bn_gamma = relay.var("gamma", shape=(64,), dtype='float32') + bn_beta = relay.var("beta", shape=(64,), dtype='float32') + bn_mean = relay.var("mean", shape=(64,), dtype='float32') + bn_var = relay.var("var", shape=(64,), dtype='float32') + params = { + "weight": np.random.uniform(-1, 1, size=(64, 3, 7, 7)).astype('float32'), + "gamma": np.random.uniform(-1, 1, size=(64, )).astype('float32'), + "beta": np.random.uniform(-1, 1, size=(64, )).astype('float32'), + "mean": np.random.uniform(-1, 1, size=(64, )).astype('float32'), + "var": np.random.uniform(-1, 1, size=(64, )).astype('float32'), + } + conv_bf16 = relay.nn.conv2d(relay.cast(data, 'bfloat16'), relay.cast(weight, 'bfloat16'), + strides=(2, 2), padding=(3, 3, 3, 3), channels=64, kernel_size=(7, 7), out_dtype='bfloat16') + bn_bf16 = relay.nn.batch_norm(conv_bf16, relay.cast(bn_gamma, 'bfloat16'), + relay.cast(bn_beta, 'bfloat16'), relay.cast(bn_mean, 'bfloat16'), relay.cast(bn_var, 'bfloat16')) + relu_bf16 = relay.nn.relu(bn_bf16[0]) + maxpool_bf16 = relay.nn.max_pool2d( + relu_bf16, pool_size=(2, 2), strides=(2, 2)) + avgpool_bf16 = relay.nn.avg_pool2d( + maxpool_bf16, pool_size=(2, 2), strides=(2, 2)) + mod_bf16 = tvm.IRModule.from_expr(avgpool_bf16) + with tvm.transform.PassContext(opt_level=3): + relay.build(mod_bf16, target="llvm", params=params) + + @tvm.testing.parametrize_targets("llvm", "cuda") def test_fp16_conversion(target, dev): if target == "cuda" and not have_fp16(dev.compute_version): @@ -126,3 +155,4 @@ def test_fp16_conversion(target, dev): test_basic_build() test_fp16_build() test_fp16_conversion() + test_bf16_build() From 07ea8de87d80da05e1fb7e7936c9287c2c4f7211 Mon Sep 17 00:00:00 2001 From: yangulei Date: Fri, 21 Jan 2022 08:15:55 +0800 Subject: [PATCH 11/16] code style --- include/tvm/tir/op.h | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h index 2de3056f0bd1..60ad55102029 100644 --- a/include/tvm/tir/op.h +++ b/include/tvm/tir/op.h @@ -835,18 +835,18 @@ TVM_DLL PrimExpr q_multiply_shift(PrimExpr x, PrimExpr y, PrimExpr q, PrimExpr s Span span = Span()); // Intrinsic operators -#define TVM_DECLARE_INTRIN_UNARY(OpName) \ - inline PrimExpr OpName(PrimExpr x, Span span = Span()) { \ - static const Op& op = Op::Get("tir." #OpName); \ - if (x.dtype().is_bfloat16()) { \ - DataType srcType = x.dtype(); \ - DataType dstType(kDLFloat, 32, srcType.lanes()); \ - PrimExpr castX = tir::Cast(dstType, {x}, span); \ - PrimExpr result = tir::Call(dstType, op, {castX}, span); \ - return tir::Cast(srcType, {result}, span); \ - } else { \ - return tir::Call(x.dtype(), op, {x}, span); \ - } \ +#define TVM_DECLARE_INTRIN_UNARY(OpName) \ + inline PrimExpr OpName(PrimExpr x, Span span = Span()) { \ + static const Op& op = Op::Get("tir." 
#OpName); \ + if (x.dtype().is_bfloat16()) { \ + DataType srcType = x.dtype(); \ + DataType dstType(kDLFloat, 32, srcType.lanes()); \ + PrimExpr castX = tir::Cast(dstType, {x}, span); \ + PrimExpr result = tir::Call(dstType, op, {castX}, span); \ + return tir::Cast(srcType, {result}, span); \ + } else { \ + return tir::Call(x.dtype(), op, {x}, span); \ + } \ } TVM_DECLARE_INTRIN_UNARY(exp); From 5583b86d55194641f54e9c5ab300d6468265e084 Mon Sep 17 00:00:00 2001 From: yangulei Date: Fri, 21 Jan 2022 08:17:56 +0800 Subject: [PATCH 12/16] restore the modifications in .gitignore --- .gitignore | 3 --- 1 file changed, 3 deletions(-) diff --git a/.gitignore b/.gitignore index 420aacf0e343..b2b6afb21544 100644 --- a/.gitignore +++ b/.gitignore @@ -11,10 +11,7 @@ __pycache__/ .Python env/ build/ -build_debug/ -build_release/ develop-eggs/ -dev_tvm/ dist/ downloads/ eggs/ From a526feeee4a8a136d7d18d8c7af687b811e89581 Mon Sep 17 00:00:00 2001 From: yangulei Date: Thu, 27 Jan 2022 14:29:42 +0800 Subject: [PATCH 13/16] restore the changes to AMP lists --- python/tvm/relay/transform/mixed_precision.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/transform/mixed_precision.py b/python/tvm/relay/transform/mixed_precision.py index 9bf6d7abd868..5018ba9ba9a7 100644 --- a/python/tvm/relay/transform/mixed_precision.py +++ b/python/tvm/relay/transform/mixed_precision.py @@ -40,8 +40,6 @@ "nn.conv3d_transpose", "nn.dense", "nn.batch_matmul", - "nn.bias_add", - "nn.batch_norm", ] DEFAULT_FOLLOW_LIST = [ # These ops add new data or change shape @@ -82,6 +80,8 @@ "subtract", "multiply", "divide", + "nn.bias_add", + "nn.batch_norm", "sqrt", "shape_of", # Simple activations From a1a3c31fcccc0ca609053540ccd1b3b36f7eec21 Mon Sep 17 00:00:00 2001 From: yangulei Date: Thu, 10 Feb 2022 09:37:49 +0800 Subject: [PATCH 14/16] fix typos --- src/tir/op/op.cc | 4 ++-- src/tir/transforms/lower_intrin.cc | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc index 330db9a6d17f..1e7b041505fc 100644 --- a/src/tir/op/op.cc +++ b/src/tir/op/op.cc @@ -130,11 +130,11 @@ void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs, Span span) { // NOLINT(*) rhs = cast(ltype, rhs); } else if (!ltype.is_bfloat16() && (rtype.is_bfloat16() || datatype::Registry::Global()->GetTypeRegistered(rtype.code()))) { - // Cast int->float when the other operand is a float + // Cast int->bfloat16 when the other operand is a bfloat16 lhs = cast(rtype, lhs); } else if ((ltype.is_bfloat16() || datatype::Registry::Global()->GetTypeRegistered(ltype.code())) && !rtype.is_bfloat16()) { - // Cast int->float when the other operand is a float + // Cast int->bfloat16 when the other operand is a bfloat16 rhs = cast(ltype, rhs); } else if ((ltype.is_int() && rtype.is_int()) || (ltype.is_uint() && rtype.is_uint())) { // Promote int to higher bits e.g. 
int8 + int16 --> int16 + int16 diff --git a/src/tir/transforms/lower_intrin.cc b/src/tir/transforms/lower_intrin.cc index f22417b78a70..2555002d29b0 100644 --- a/src/tir/transforms/lower_intrin.cc +++ b/src/tir/transforms/lower_intrin.cc @@ -269,7 +269,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { PrimExpr lhs = SwapBroadcastCast(a); PrimExpr rhs = SwapBroadcastCast(b); - if (fma_ != nullptr && (op->dtype.is_float())) { + if (fma_ != nullptr && op->dtype.is_float()) { PrimExpr r = fma_(Call(op->dtype, builtin::fma(), {lhs, rhs, c})); if (r.defined()) return this->VisitExpr(r); } else { From 689eead12cd5412cf7902493bfaee9287f2817a9 Mon Sep 17 00:00:00 2001 From: yangulei Date: Fri, 18 Feb 2022 14:14:11 +0800 Subject: [PATCH 15/16] fix lint errors --- src/tir/op/op.cc | 6 ++- src/tir/transforms/arg_binder.cc | 4 +- tests/python/relay/test_cpp_build_module.py | 48 +++++++++++++-------- 3 files changed, 35 insertions(+), 23 deletions(-) diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc index 1e7b041505fc..1a9a73e9dc94 100644 --- a/src/tir/op/op.cc +++ b/src/tir/op/op.cc @@ -129,10 +129,12 @@ void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs, Span span) { // NOLINT(*) // Cast int->float when the other operand is a float rhs = cast(ltype, rhs); } else if (!ltype.is_bfloat16() && - (rtype.is_bfloat16() || datatype::Registry::Global()->GetTypeRegistered(rtype.code()))) { + (rtype.is_bfloat16() || + datatype::Registry::Global()->GetTypeRegistered(rtype.code()))) { // Cast int->bfloat16 when the other operand is a bfloat16 lhs = cast(rtype, lhs); - } else if ((ltype.is_bfloat16() || datatype::Registry::Global()->GetTypeRegistered(ltype.code())) && + } else if ((ltype.is_bfloat16() || + datatype::Registry::Global()->GetTypeRegistered(ltype.code())) && !rtype.is_bfloat16()) { // Cast int->bfloat16 when the other operand is a bfloat16 rhs = cast(ltype, rhs); diff --git a/src/tir/transforms/arg_binder.cc b/src/tir/transforms/arg_binder.cc index f4c2a4d4b1e0..1e566a980463 100644 --- a/src/tir/transforms/arg_binder.cc +++ b/src/tir/transforms/arg_binder.cc @@ -169,8 +169,8 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type, IntImm(DataType::UInt(8), dtype.bits()) && TVMArrayGet(DataType::UInt(16), handle, builtin::kArrTypeLanes) == IntImm(DataType::UInt(16), dtype.lanes())); - if (!(dtype == DataType::Int(4) || dtype == DataType::UInt(4) || - dtype == DataType::Int(1) || dtype == DataType::UInt(16))) { + if (!(dtype == DataType::Int(1) || dtype == DataType::Int(4) || dtype == DataType::UInt(4) || + dtype == DataType::UInt(16))) { auto type_msg = tvm::tir::StringImm(type_err_msg.str()); asserts_.emplace_back(AssertStmt(a_ndim == v_ndim, msg, nop)); asserts_.emplace_back(AssertStmt(cond, type_msg, nop)); diff --git a/tests/python/relay/test_cpp_build_module.py b/tests/python/relay/test_cpp_build_module.py index ab9fa7aca442..fc6e7c2eeb28 100644 --- a/tests/python/relay/test_cpp_build_module.py +++ b/tests/python/relay/test_cpp_build_module.py @@ -95,28 +95,38 @@ def test_fp16_build(): @tvm.testing.requires_llvm def test_bf16_build(): - data = relay.var("data", shape=(1, 3, 224, 224), dtype='float32') - weight = relay.var("weight", shape=(64, 3, 7, 7), dtype='float32') - bn_gamma = relay.var("gamma", shape=(64,), dtype='float32') - bn_beta = relay.var("beta", shape=(64,), dtype='float32') - bn_mean = relay.var("mean", shape=(64,), dtype='float32') - bn_var = relay.var("var", shape=(64,), dtype='float32') + data = relay.var("data", shape=(1, 3, 224, 
224), dtype="flaot32") + weight = relay.var("weight", shape=(64, 3, 7, 7), dtype="flaot32") + bn_gamma = relay.var("gamma", shape=(64,), dtype="flaot32") + bn_beta = relay.var("beta", shape=(64,), dtype="flaot32") + bn_mean = relay.var("mean", shape=(64,), dtype="flaot32") + bn_var = relay.var("var", shape=(64,), dtype="flaot32") params = { - "weight": np.random.uniform(-1, 1, size=(64, 3, 7, 7)).astype('float32'), - "gamma": np.random.uniform(-1, 1, size=(64, )).astype('float32'), - "beta": np.random.uniform(-1, 1, size=(64, )).astype('float32'), - "mean": np.random.uniform(-1, 1, size=(64, )).astype('float32'), - "var": np.random.uniform(-1, 1, size=(64, )).astype('float32'), + "weight": np.random.uniform(-1, 1, size=(64, 3, 7, 7)).astype("flaot32"), + "gamma": np.random.uniform(-1, 1, size=(64,)).astype("flaot32"), + "beta": np.random.uniform(-1, 1, size=(64,)).astype("flaot32"), + "mean": np.random.uniform(-1, 1, size=(64,)).astype("flaot32"), + "var": np.random.uniform(-1, 1, size=(64,)).astype("flaot32"), } - conv_bf16 = relay.nn.conv2d(relay.cast(data, 'bfloat16'), relay.cast(weight, 'bfloat16'), - strides=(2, 2), padding=(3, 3, 3, 3), channels=64, kernel_size=(7, 7), out_dtype='bfloat16') - bn_bf16 = relay.nn.batch_norm(conv_bf16, relay.cast(bn_gamma, 'bfloat16'), - relay.cast(bn_beta, 'bfloat16'), relay.cast(bn_mean, 'bfloat16'), relay.cast(bn_var, 'bfloat16')) + conv_bf16 = relay.nn.conv2d( + relay.cast(data, "bfloat16"), + relay.cast(weight, "bfloat16"), + strides=(2, 2), + padding=(3, 3, 3, 3), + channels=64, + kernel_size=(7, 7), + out_dtype="bfloat16", + ) + bn_bf16 = relay.nn.batch_norm( + conv_bf16, + relay.cast(bn_gamma, "bfloat16"), + relay.cast(bn_beta, "bfloat16"), + relay.cast(bn_mean, "bfloat16"), + relay.cast(bn_var, "bfloat16"), + ) relu_bf16 = relay.nn.relu(bn_bf16[0]) - maxpool_bf16 = relay.nn.max_pool2d( - relu_bf16, pool_size=(2, 2), strides=(2, 2)) - avgpool_bf16 = relay.nn.avg_pool2d( - maxpool_bf16, pool_size=(2, 2), strides=(2, 2)) + maxpool_bf16 = relay.nn.max_pool2d(relu_bf16, pool_size=(2, 2), strides=(2, 2)) + avgpool_bf16 = relay.nn.avg_pool2d(maxpool_bf16, pool_size=(2, 2), strides=(2, 2)) mod_bf16 = tvm.IRModule.from_expr(avgpool_bf16) with tvm.transform.PassContext(opt_level=3): relay.build(mod_bf16, target="llvm", params=params) From 09a66f42de76365150a1a47069428225c1e91ec4 Mon Sep 17 00:00:00 2001 From: yangulei Date: Mon, 21 Feb 2022 13:28:18 +0800 Subject: [PATCH 16/16] fix typo --- tests/python/relay/test_cpp_build_module.py | 22 ++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/tests/python/relay/test_cpp_build_module.py b/tests/python/relay/test_cpp_build_module.py index fc6e7c2eeb28..ccf961fbe4db 100644 --- a/tests/python/relay/test_cpp_build_module.py +++ b/tests/python/relay/test_cpp_build_module.py @@ -95,18 +95,18 @@ def test_fp16_build(): @tvm.testing.requires_llvm def test_bf16_build(): - data = relay.var("data", shape=(1, 3, 224, 224), dtype="flaot32") - weight = relay.var("weight", shape=(64, 3, 7, 7), dtype="flaot32") - bn_gamma = relay.var("gamma", shape=(64,), dtype="flaot32") - bn_beta = relay.var("beta", shape=(64,), dtype="flaot32") - bn_mean = relay.var("mean", shape=(64,), dtype="flaot32") - bn_var = relay.var("var", shape=(64,), dtype="flaot32") + data = relay.var("data", shape=(1, 3, 224, 224), dtype="float32") + weight = relay.var("weight", shape=(64, 3, 7, 7), dtype="float32") + bn_gamma = relay.var("gamma", shape=(64,), dtype="float32") + bn_beta = relay.var("beta", shape=(64,), 
dtype="float32") + bn_mean = relay.var("mean", shape=(64,), dtype="float32") + bn_var = relay.var("var", shape=(64,), dtype="float32") params = { - "weight": np.random.uniform(-1, 1, size=(64, 3, 7, 7)).astype("flaot32"), - "gamma": np.random.uniform(-1, 1, size=(64,)).astype("flaot32"), - "beta": np.random.uniform(-1, 1, size=(64,)).astype("flaot32"), - "mean": np.random.uniform(-1, 1, size=(64,)).astype("flaot32"), - "var": np.random.uniform(-1, 1, size=(64,)).astype("flaot32"), + "weight": np.random.uniform(-1, 1, size=(64, 3, 7, 7)).astype("float32"), + "gamma": np.random.uniform(-1, 1, size=(64,)).astype("float32"), + "beta": np.random.uniform(-1, 1, size=(64,)).astype("float32"), + "mean": np.random.uniform(-1, 1, size=(64,)).astype("float32"), + "var": np.random.uniform(-1, 1, size=(64,)).astype("float32"), } conv_bf16 = relay.nn.conv2d( relay.cast(data, "bfloat16"),