[INT8][BF16] INT8 + BF16 feature was enabled (openvinotoolkit#5059)
alexey-varyzgin authored May 3, 2021
1 parent bb022e2 commit 7d2ec02
Showing 7 changed files with 63 additions and 6 deletions.
7 changes: 5 additions & 2 deletions inference-engine/src/mkldnn_plugin/config.cpp
@@ -98,12 +98,15 @@ void Config::readProperties(const std::map<std::string, std::string> &prop) {
             dumpQuantizedGraphToIr = val;
         } else if (key == PluginConfigParams::KEY_ENFORCE_BF16) {
             if (val == PluginConfigParams::YES) {
-                if (with_cpu_x86_avx512_core())
+                if (with_cpu_x86_avx512_core()) {
                     enforceBF16 = true;
-                else
+                    manualEnforceBF16 = true;
+                } else {
                     IE_THROW() << "Platform doesn't support BF16 format";
+                }
             } else if (val == PluginConfigParams::NO) {
                 enforceBF16 = false;
+                manualEnforceBF16 = false;
             } else {
                 IE_THROW() << "Wrong value for property key " << PluginConfigParams::KEY_ENFORCE_BF16
                            << ". Expected only YES/NO";
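For context, the sketch below shows how an application would opt in to this path (a minimal usage sketch assuming the standard Inference Engine 2021.x C++ API; "model.xml" is a hypothetical path, not from this commit). Setting KEY_ENFORCE_BF16 to YES on an AVX512-capable CPU now also raises manualEnforceBF16, which is what later lets BF16 be combined with INT8/BIN models:

    #include <map>
    #include <string>
    #include <inference_engine.hpp>

    int main() {
        InferenceEngine::Core core;
        auto network = core.ReadNetwork("model.xml");  // hypothetical model path

        // Request BF16 enforcement explicitly; per config.cpp above this now also
        // flips manualEnforceBF16, and throws if the CPU lacks AVX512 Core support.
        std::map<std::string, std::string> config = {
            { InferenceEngine::PluginConfigParams::KEY_ENFORCE_BF16,
              InferenceEngine::PluginConfigParams::YES } };

        auto execNetwork = core.LoadNetwork(network, "CPU", config);
        auto request = execNetwork.CreateInferRequest();
        request.Infer();
        return 0;
    }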
1 change: 1 addition & 0 deletions inference-engine/src/mkldnn_plugin/config.h
@@ -34,6 +34,7 @@ struct Config {
 #else
     LPTransformsMode lpTransformsMode = LPTransformsMode::On;
     bool enforceBF16 = true;
+    bool manualEnforceBF16 = false;
 #endif

     void readProperties(const std::map<std::string, std::string> &config);
13 changes: 9 additions & 4 deletions inference-engine/src/mkldnn_plugin/mkldnn_exec_network.cpp
@@ -52,8 +52,6 @@ MKLDNNExecNetwork::MKLDNNExecNetwork(const InferenceEngine::CNNNetwork &network,
     bool isFloatModel = true;
     if (_cfg.lpTransformsMode == Config::LPTransformsMode::On) {
         // Check if network is INT8 or Binary.
-        // BF16 transformations were disabled since CPU plug-in doesn't support mixed precision execution:
-        // BF16 + INT8 or BF16 + BIN.
         CNNNetworkIterator iter(network);
         while (iter != CNNNetworkIterator()) {
             if (CaselessEq<std::string>()((*iter)->type, "FakeQuantize")) {
@@ -87,12 +85,19 @@ MKLDNNExecNetwork::MKLDNNExecNetwork(const InferenceEngine::CNNNetwork &network,
         }
     };

-    if (with_cpu_x86_avx512_core() && isFloatModel) {
+    if (with_cpu_x86_avx512_core()) {
         // If enforceBF16 flag was set, BF16 transformation applies for all layers supported by CPU plugin.
         // Otherwise, only layers marked as BF16 in '_clonedNetwork' will be performed in bfloat16 mode.
         // CPU plugin throws an exception, if marked as BF16 layers have not supported by CPU plugin.
-        if (cfg.enforceBF16 == true)
+
+        // BF16 + INT8 or BF16 + BIN models will be performed in mixed precision execution only if
+        // enforceBF16 flag was set manually
+        if (isFloatModel == false) {
+            if (cfg.manualEnforceBF16 == true)
+                changePrecisionBF16(Precision::FP32, Precision::BF16);
+        } else if (cfg.enforceBF16 == true) {
             changePrecisionBF16(Precision::FP32, Precision::BF16);
+        }
     } else {
         changePrecisionBF16(Precision::BF16, Precision::FP32);
     }
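The branching added above amounts to a small decision table. The following sketch only paraphrases that control flow; the enum and function names are illustrative and not part of the plugin:

    enum class BF16Action { UpgradeToBF16, DowngradeToFP32, Keep };

    // Mirrors the logic in MKLDNNExecNetwork above: FP32 layers are switched to BF16
    // for float models when enforceBF16 is set, and for INT8/BIN (mixed precision)
    // models only when the user set ENFORCE_BF16 manually; without AVX512 Core
    // support, any BF16 marks are downgraded back to FP32.
    BF16Action chooseBF16Action(bool hasAVX512Core, bool isFloatModel,
                                bool enforceBF16, bool manualEnforceBF16) {
        if (!hasAVX512Core)
            return BF16Action::DowngradeToFP32;
        if (!isFloatModel)
            return manualEnforceBF16 ? BF16Action::UpgradeToBF16 : BF16Action::Keep;
        return enforceBF16 ? BF16Action::UpgradeToBF16 : BF16Action::Keep;
    }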
29 changes: 29 additions & 0 deletions inference-engine/src/mkldnn_plugin/mkldnn_graph_optimizer.cpp
@@ -664,6 +664,15 @@ void MKLDNNGraphOptimizer::FuseConvolutionAndActivation(MKLDNNGraph &graph) {
     }
 }

+static bool BF16QuantizeNodeFusing(MKLDNNNodePtr parentNode, MKLDNNNodePtr childNode) {
+    return childNode->getType() == Quantize &&
+        one_of(Precision::BF16,
+            parentNode->getCnnLayer()->precision,
+            childNode->getCnnLayer()->precision,
+            parentNode->getCnnLayer()->outData[0].get()->getPrecision(),
+            childNode->getCnnLayer()->outData[0].get()->getPrecision());
+}
+
 void MKLDNNGraphOptimizer::FuseFullyConnectedAndSimpleOperation(MKLDNNGraph &graph) {
     auto& graphNodes = graph.GetNodes();

@@ -754,6 +763,12 @@ void MKLDNNGraphOptimizer::FuseFullyConnectedAndSimpleOperation(MKLDNNGraph &graph) {
             continue;
         }

+        // BF16 Quantize Layer Fusing Disabling
+        if (BF16QuantizeNodeFusing(parentNode, childNode)) {
+            parent++;
+            continue;
+        }
+
         parentNode->fuseWith(childNode);

         if (childNode->getType() == Quantize || childNode->getType() == Eltwise) {
@@ -1011,6 +1026,10 @@ void MKLDNNGraphOptimizer::FuseConvolutionAndQuantize(MKLDNNGraph &graph) {
         auto child = parent->getChildEdgeAt(0)->getChild();
         if (!isSutableChildNode(child)) continue;

+        // BF16 Quantize Layer Fusing Disabling
+        if (BF16QuantizeNodeFusing(parent, child))
+            continue;
+
         parent->fuseWith(child);

         auto parents = child->parentEdges;
@@ -1073,6 +1092,12 @@ void MKLDNNGraphOptimizer::FuseConvolutionAndSimpleOperation(MKLDNNGraph &graph)
             continue;
         }

+        // BF16 Quantize Layer Fusing Disabling
+        if (BF16QuantizeNodeFusing(parentNode, childNode)) {
+            parent++;
+            continue;
+        }
+
         parentNode->fuseWith(childNode);

         if (childNode->getType() == Quantize || childNode->getType() == Eltwise) {
@@ -1117,6 +1142,10 @@ void MKLDNNGraphOptimizer::FuseBinaryConvolutionAndQuantize(MKLDNNGraph &graph)
         auto child = parent->getChildEdgeAt(0)->getChild();
         if (!isSutableChildNode(parent, child)) continue;

+        // BF16 Quantize Layer Fusing Disabling
+        if (BF16QuantizeNodeFusing(parent, child))
+            continue;
+
         parent->fuseWith(child);

         auto parents = child->parentEdges;
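All four fusing passes now apply the same guard: fusion into a Quantize node is skipped whenever BF16 appears on either side of the edge. Below is a simplified, self-contained illustration of that check; the Node struct and oneOf helper are stand-ins, not the plugin's real MKLDNNNode/CNNLayer types:

    #include <initializer_list>
    #include <string>

    enum class Precision { FP32, BF16, I8, U8 };

    struct Node {                      // stand-in for the node/layer fields used above
        std::string type;
        Precision layerPrecision;      // getCnnLayer()->precision
        Precision outPrecision;        // getCnnLayer()->outData[0]->getPrecision()
    };

    static bool oneOf(Precision value, std::initializer_list<Precision> candidates) {
        for (auto c : candidates)
            if (c == value)
                return true;
        return false;
    }

    // Same idea as BF16QuantizeNodeFusing: refuse to fuse into a Quantize node if
    // BF16 shows up in either node's layer precision or first output precision.
    static bool bf16QuantizeFusingBlocked(const Node& parent, const Node& child) {
        return child.type == "Quantize" &&
               oneOf(Precision::BF16, { parent.layerPrecision, child.layerPrecision,
                                        parent.outPrecision, child.outPrecision });
    }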
5 changes: 5 additions & 0 deletions inference-engine/src/mkldnn_plugin/nodes/mkldnn_conv_node.cpp
@@ -278,6 +278,11 @@ void MKLDNNConvolutionNode::getSupportedDescriptors() {

     MKLDNNMemoryDesc in_candidate, out_candidate;
     if (canBeExecutedInInt8()) {
+        // We have to extend convolution_x8s8s32x from oneDNN to support BF16 output data type
+        if (outputDataType == memory::data_type::bf16)
+            outputDataType = memory::data_type::f32;
+        if (eltwisePrecision == Precision::BF16)
+            eltwisePrecision = Precision::FP32;
         in_candidate = MKLDNNMemoryDesc(getParentEdgeAt(0)->getDims(), inputDataType,
             getParentEdgeAt(0)->getDims().ndims() == 5 ? memory::format_tag::ndhwc : memory::format_tag::nhwc);
         out_candidate = MKLDNNMemoryDesc(getChildEdgeAt(0)->getDims(), outputDataType,
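The convolution change is a plain precision fallback: oneDNN's convolution_x8s8s32x primitives do not produce BF16 output, so an INT8 convolution that was asked for BF16 falls back to FP32 (the eltwise fuse precision is clamped the same way). A hedged restatement of that rule with a simplified enum, not the plugin's types:

    enum class DataType { f32, bf16, u8, s8 };

    // If the INT8 convolution path cannot emit BF16, fall back to FP32 output,
    // mirroring the clamp added in MKLDNNConvolutionNode::getSupportedDescriptors().
    DataType resolveInt8ConvOutput(DataType requested) {
        return requested == DataType::bf16 ? DataType::f32 : requested;
    }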
@@ -71,13 +71,24 @@ void MKLDNNFullyConnectedNode::getSupportedDescriptors() {
         }
         auto weightsDataType = MKLDNNExtensionUtils::IEPrecisionToDataType(getCnnLayer()->insData[1].lock()->getPrecision());

+        // We have to extend gemm_x8s8s32x_inner_product_fwd_t from oneDNN to support BF16 output data type
         if ((!one_of(inputDataType , memory::data_type::u8, memory::data_type::s8) || weightsDataType != memory::data_type::s8) &&
                 inputDataType != memory::data_type::bf16) {
             inputDataType = memory::data_type::f32;
             outputDataType = memory::data_type::f32;
         }
     }

+    if (one_of(inputDataType , memory::data_type::u8, memory::data_type::s8)
+            && outputDataType == memory::data_type::bf16) {
+        outputDataType = memory::data_type::f32;
+    }
+
+    if (inputDataType == memory::data_type::bf16
+            && one_of(outputDataType , memory::data_type::u8, memory::data_type::s8)) {
+        outputDataType = memory::data_type::bf16;
+    }
+
     auto * fcLayer = dynamic_cast<FullyConnectedLayer*>(getCnnLayer().get());
     if (fcLayer == nullptr)
         IE_THROW() << "Cannot convert fully connected layer.";
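The fully connected node gets the same fallback plus one extra rule for the opposite direction. A compact restatement of how input and requested output types are reconciled (illustrative function and enum, not the plugin's API):

    enum class DataType { f32, bf16, u8, s8 };

    static bool isInt8(DataType t) { return t == DataType::u8 || t == DataType::s8; }

    // Mirrors the adjustments in MKLDNNFullyConnectedNode::getSupportedDescriptors():
    //  - INT8 input with a BF16 output request falls back to FP32 output
    //    (gemm_x8s8s32x_inner_product has no BF16 output support yet);
    //  - BF16 input with an INT8 output request keeps the output in BF16.
    DataType resolveFcOutput(DataType input, DataType requestedOutput) {
        if (isInt8(input) && requestedOutput == DataType::bf16)
            return DataType::f32;
        if (input == DataType::bf16 && isInt8(requestedOutput))
            return DataType::bf16;
        return requestedOutput;
    }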
@@ -104,6 +104,9 @@ void MKLDNNPoolingNode::getSupportedDescriptors() {
         effective_pad_end[i] = (dst - calc_dst) * stride[i];
     }
     if (inputPrecision == Precision::I8 || inputPrecision == Precision::U8) {
+        // We have to extend i8i8_pooling_fwd_t from oneDNN to support BF16 output data type
+        if (outputDataType == memory::data_type::bf16)
+            outputDataType = memory::data_type::f32;
         // i8 layers supports only ndhwc and nhwc layouts
         MKLDNNMemoryDesc in_candidate{parentDims, inputDataType, parentDims.ndims() == 5 ? memory::format_tag::ndhwc : memory::format_tag::nhwc};
         MKLDNNMemoryDesc out_candidate{childDims, outputDataType, parentDims.ndims() == 5 ? memory::format_tag::ndhwc : memory::format_tag::nhwc};
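Pooling follows the same INT8-path fallback as convolution. To confirm which precision each layer actually executed in after these changes, one option (a hedged sketch against the Inference Engine 2021.x API; not part of this commit) is to enable performance counters via KEY_PERF_COUNT and inspect the exec_type reported per layer, which names the executed kernel and typically encodes its precision:

    #include <iostream>
    #include <inference_engine.hpp>

    void printLayerKernels(InferenceEngine::InferRequest& request) {
        // Requires KEY_PERF_COUNT=YES in the LoadNetwork config.
        for (const auto& kv : request.GetPerformanceCounts()) {
            const auto& info = kv.second;
            std::cout << kv.first << ": layer_type=" << info.layer_type
                      << " exec_type=" << info.exec_type << std::endl;
        }
    }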