diff --git a/src/operator/subgraph/mkldnn/mkldnn_fc-inl.h b/src/operator/subgraph/mkldnn/mkldnn_fc-inl.h
index de8a2569261d..d0881086ac97 100644
--- a/src/operator/subgraph/mkldnn/mkldnn_fc-inl.h
+++ b/src/operator/subgraph/mkldnn/mkldnn_fc-inl.h
@@ -33,9 +33,13 @@ namespace op {
 static inline bool SupportMKLDNNFCEltwiseFusion(const std::string op_name) {
   if (op_name == "Activation" ||
       op_name == "square" ||
+      op_name == "_npi_square" ||
       op_name == "sqrt" ||
+      op_name == "_npi_sqrt" ||
       op_name == "exp" ||
+      op_name == "_npi_exp" ||
       op_name == "abs" ||
+      op_name == "_npi_absolute" ||
       op_name == "clip" ||
       op_name == "LeakyReLU") {
     return true;
@@ -45,13 +49,13 @@ static inline bool SupportMKLDNNFCEltwiseFusion(const std::string op_name) {
 }
 
 static inline mkldnn::algorithm GetMKLDNNEltwiseAlgo(const std::string op_name) {
-  if (op_name == "square")
+  if (op_name == "square" || op_name == "_npi_square")
     return mkldnn::algorithm::eltwise_square;
-  else if (op_name == "sqrt")
+  else if (op_name == "sqrt" || op_name == "_npi_sqrt")
     return mkldnn::algorithm::eltwise_sqrt;
-  else if (op_name == "exp")
+  else if (op_name == "exp" || op_name == "_npi_exp")
     return mkldnn::algorithm::eltwise_exp;
-  else if (op_name == "abs")
+  else if (op_name == "abs" || op_name == "_npi_absolute")
     return mkldnn::algorithm::eltwise_abs;
   else
     LOG(FATAL) << "Unsupported eltwise fusion op: " << op_name;
diff --git a/src/operator/subgraph/mkldnn/mkldnn_fc_property.h b/src/operator/subgraph/mkldnn/mkldnn_fc_property.h
index 9a0c7770a22d..a96c71153869 100644
--- a/src/operator/subgraph/mkldnn/mkldnn_fc_property.h
+++ b/src/operator/subgraph/mkldnn/mkldnn_fc_property.h
@@ -113,13 +113,17 @@ class SgMKLDNNFCSelector : public SubgraphSelector {
       }
     }
     if (!quantized_ && (new_node.op() == Op::Get("square") ||
-                        new_node.op() == Op::Get("sqrt") ||
-                        new_node.op() == Op::Get("exp"))) {
+                        new_node.op() == Op::Get("_npi_square") ||
+                        new_node.op() == Op::Get("sqrt") ||
+                        new_node.op() == Op::Get("_npi_sqrt") ||
+                        new_node.op() == Op::Get("exp") ||
+                        new_node.op() == Op::Get("_npi_exp"))) {
       matched_list_.push_back(&new_node);
       status_ = kSuccess;
       return true;
     }
-    if (new_node.op() == Op::Get("abs")) {
+    if (new_node.op() == Op::Get("abs") ||
+        new_node.op() == Op::Get("_npi_absolute")) {
       matched_list_.push_back(&new_node);
       status_ = kSuccess;
       return true;
diff --git a/tests/python/mkl/subgraphs/subgraph_common.py b/tests/python/mkl/subgraphs/subgraph_common.py
index 9f518414ac91..4467443166cb 100644
--- a/tests/python/mkl/subgraphs/subgraph_common.py
+++ b/tests/python/mkl/subgraphs/subgraph_common.py
@@ -71,13 +71,16 @@ class CustomNormalInit(mx.init.Initializer):
     """Initializes weights with random values sampled from a normal distribution
     with a custom mean and standard deviation of `sigma`.
     """
-    def __init__(self, mean=0, sigma=0.01):
-        super(CustomNormalInit, self).__init__(mean=mean, sigma=sigma)
+    def __init__(self, mean=0, sigma=0.01, bounded=False):
+        super(CustomNormalInit, self).__init__(mean=mean, sigma=sigma, bounded=bounded)
         self.mean = mean
         self.sigma = sigma
+        self.bounded = bounded
 
     def _init_weight(self, _, arr):
         mx.np.random.normal(self.mean, self.sigma, arr.shape, dtype=arr.dtype, out=arr)
+        if self.bounded:
+            mx.np.abs(arr, out=arr)
 
 
 def check_qsym_calibrated(qsym, out_type, name='conv'):
diff --git a/tests/python/mkl/subgraphs/test_fc_subgraph.py b/tests/python/mkl/subgraphs/test_fc_subgraph.py
index 1bcd332e3b8c..5c0aee7500c1 100644
--- a/tests/python/mkl/subgraphs/test_fc_subgraph.py
+++ b/tests/python/mkl/subgraphs/test_fc_subgraph.py
@@ -59,18 +59,19 @@ def forward(self, x):
 @pytest.mark.parametrize('use_bias', [True, False])
 @pytest.mark.parametrize('flatten', [True, False])
 @pytest.mark.parametrize('alg', fc_post_ops_list)
-@pytest.mark.skip("Operator square, square_root, abs, exp cannot be found in numpy mode")
 def test_fc_eltwise(data_shape, use_bias, flatten, alg):
   # fc + eltwise fusion case
   class FCEltwise(nn.HybridBlock):
     def __init__(self, use_bias, flatten, alg, **kwargs):
       super(FCEltwise, self).__init__(**kwargs)
       self.fc = nn.Dense(units=64, use_bias=use_bias, flatten=flatten,
-                         weight_initializer=CustomNormalInit(mean=0.5, sigma=0.1) if alg == 'square_root' else None)
+                         weight_initializer=CustomNormalInit(mean=0.5, sigma=0.1, bounded=True) if alg == 'square_root' else None)
                          #avoid calculating square root of negative values
       self.alg = alg
 
     def forward(self, x):
+      if self.alg == 'square_root':
+        x = abs(x)
       fc_out = self.fc(x)
       if self.alg in ['relu', 'sigmoid', 'log_sigmoid', 'mish', 'tanh', 'softrelu']:
         out = mx.npx.activation(fc_out, act_type=self.alg)