add int8 bn mkldnn implementation and test #15664
Conversation
@mxnet-label-bot add [mkldnn, Backend, pr-awaiting-review]
@ElaineBao Can you try this on resnetv2? Theoretically, the performance will be better since there are lots of bn-relu-conv patterns in this model.
@ZhennanQin @ciyongch, please help take a review.
@ElaineBao could you elaborate a bit more on why standalone BN leads to more accuracy drop?
That's good advice, I'll try it and update the performance numbers, thank you.
Basically, the accuracy drop is not because the BN is fused or standalone; it's because the BN is converted from fp32 to int8.
@ElaineBao unfusing bn will also introduce a standalone quantized_activation along with quantized_bn.
Hi all, I've looked into the performance issue and concluded that, as an operator, int8 bn itself has no performance regression. The accuracy drop that happens in some models is due to the combination of int8 bn and other operators, which may cause a poor weight distribution.
LGTM
@ZhennanQin please take a review too.
src/operator/nn/batch_norm.cc
Outdated
@@ -396,7 +396,7 @@ void BatchNormComputeExCPU(const nnvm::NodeAttrs &attrs,
   CHECK_EQ(inputs.size(), 5U);
   const BatchNormParam &param = nnvm::get<BatchNormParam>(attrs.parsed);
   // MKLDNN batchnorm only works well on the special MKLDNN layout.
-  if (SupportMKLDNNBN(inputs[0], param) && inputs[0].IsMKLDNNData()) {
+  if (SupportMKLDNNBN(inputs[0], param) /*&& inputs[0].IsMKLDNNData() */) {
Not sure if we can remove this. @TaoLv, please double check.
if (param.min_calib_range.has_value() && param.max_calib_range.has_value()) {
  *max_output_ptr =
      std::max(std::abs(param.min_calib_range.value()), std::abs(param.max_calib_range.value()));
*min_output_ptr = -*max_output_ptr; |
why not *min_output_ptr = param.min_calib_range.value()?
For example, if min_calib_range=-5 and max_calib_range=10, then *max_output_ptr=10 and *min_output_ptr=-10; it's symmetric.
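For illustration, a minimal standalone sketch of that arithmetic using the example values above (the variable names here are illustrative, not the operator's actual code):

#include <algorithm>
#include <cmath>
#include <cstdio>

int main() {
  // Example calibration range from the discussion above.
  float min_calib_range = -5.0f;
  float max_calib_range = 10.0f;
  // Take the larger magnitude and mirror it, so the int8 output range is symmetric around zero.
  float max_output = std::max(std::abs(min_calib_range), std::abs(max_calib_range));
  float min_output = -max_output;
  std::printf("output range: [%.1f, %.1f]\n", min_output, max_output);  // prints [-10.0, 10.0]
  return 0;
}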
CHECK_EQ(weight_mem.get_primitive_desc().get_size(), channel_count * sizeof(float) * 2);
float *weight_buf = reinterpret_cast<float *>(weight_mem.get_data_handle());

NDArray gamma = in_data[quantized_batchnorm::kGamma]; |
change to const NDArray& gamma
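A one-line sketch of the suggested change (same index, just bound by const reference so the NDArray handle isn't copied):

// Bind by const reference instead of taking a copy of the NDArray handle.
const NDArray &gamma = in_data[quantized_batchnorm::kGamma];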
float *weight_buf = reinterpret_cast<float *>(weight_mem.get_data_handle());

NDArray gamma = in_data[quantized_batchnorm::kGamma];
NDArray beta = in_data[quantized_batchnorm::kBeta]; |
As above.
NDArray gamma = in_data[quantized_batchnorm::kGamma];
NDArray beta = in_data[quantized_batchnorm::kBeta];
float *gamma_ptr = gamma.data().dptr<float>(); |
Seems gamma is only used here. Perhaps you can remove it and use:
float *gamma_ptr = in_data[quantized_batchnorm::kGamma].data().dptr<float>();
float *moving_var_ptr = moving_var.data().dptr<float>();

// rescale gamma and beta, to make mean=0 and var=1
NDArray rescaled_mean = NDArray(moving_mean.storage_type(), moving_mean.shape(), |
Create temp memory from TmpMemMgr
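A rough sketch of what that could look like, assuming the TmpMemMgr::Get()->Alloc(pd) helper from mkldnn_base-inl.h is available; the primitive-descriptor name below is illustrative, not taken from this PR:

// Sketch only: grab scratch memory for the rescaled mean/var from the
// per-thread temp-space manager instead of allocating full NDArrays.
// `mean_pd` stands for an mkldnn::memory::primitive_desc describing a
// float buffer of channel_count elements (hypothetical name).
mkldnn::memory *rescaled_mean_mem = TmpMemMgr::Get()->Alloc(mean_pd);
mkldnn::memory *rescaled_var_mem = TmpMemMgr::Get()->Alloc(mean_pd);
float *rescaled_mean_ptr = reinterpret_cast<float *>(rescaled_mean_mem->get_data_handle());
float *rescaled_var_ptr = reinterpret_cast<float *>(rescaled_var_mem->get_data_handle());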
}

const NDArray &out = outputs[batchnorm::kOut];
auto out_mem = const_cast<NDArray &>(out).CreateMKLDNNData(fwd.GetPd().dst_primitive_desc()); |
Use CreateMKLDNNMem instead. Avoid using const_cast unless you have to.
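Something along these lines, assuming the CreateMKLDNNMem/CommitOutput helpers from mkldnn_base-inl.h and that `req` is the OpReqType vector passed into the compute function:

// Sketch only: let the helper decide whether the primitive can write the
// output in place or needs a temporary buffer, then commit the result.
const NDArray &out = outputs[batchnorm::kOut];
auto out_mem = CreateMKLDNNMem(out, fwd.GetPd().dst_primitive_desc(), req[batchnorm::kOut]);
// ... execute the batch-norm primitive, writing into out_mem.second ...
CommitOutput(out, out_mem);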
if (!dispatched) {
  dispatched = MKLDNNStorageType(attrs, dev_mask, true, dispatch_mode, in_attrs, out_attrs);
}
if (!MKLDNNEnvSet()) { |
Not necessary, MKLDNNStorageType will check this.
@@ -175,6 +175,47 @@ class MKLDNNBNForward {
    }
  }

void SetDataHandle(const NDArray &data, const mkldnn::memory *mean, |
Don't duplicate code. Make the old version a wrapper on top of this one.
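For example, the existing NDArray-based overload could simply forward to the new one; a rough sketch (the old signature is hypothetical and the trailing parameters are elided, since the new signature is only partially shown above; NDArray::GetMKLDNNData() is assumed to be available):

// Sketch only: keep the binding logic in one place and let the old
// overload delegate to the new pointer-based SetDataHandle.
void SetDataHandle(const NDArray &data, const NDArray &mean /*, ... */) {
  SetDataHandle(data, mean.GetMKLDNNData() /*, ... */);
}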
LGTM
Thanks for the great work :) Merging now.
* add int8 bn mkldnn implementation and test
* fix lint
* fix ci
* enable int8 bn test only in mkldnn backend
* disable int8 bn forward test with gpu backend
* update int8 bn with reference to comments
* fix lint
* disable int8 bn gluon forward test with gpu backend
* disable uint8 bn forward test with mkldnn backend
* restore support mkldnn bn condition
* rm duplicate code
Description
Add a new operator - int8 batch norm, mkldnn implementation and test
@pengzhao-intel @ZhennanQin
Details
Usage
Set export MXNET_DISABLE_MKLDNN_FUSE_CONV_BN=1 before using imagenet_gen_qsym_mkldnn.py to quantize the model.
Limitation
Int8 BN is not recommended with calib_mode=none, since calculating the thresholds on the fly with s8 input introduces large errors. One can run with calib_mode=naive or calib_mode=entropy, which should give accuracy similar to the fp32 model.
Performance
I tested several models on Skylake, which can be used for reference.