diff --git a/src/operator/tensor/elemwise_binary_scalar_op_extended.cc b/src/operator/tensor/elemwise_binary_scalar_op_extended.cc
index 9870342ea402..a0c4149d5c5f 100644
--- a/src/operator/tensor/elemwise_binary_scalar_op_extended.cc
+++ b/src/operator/tensor/elemwise_binary_scalar_op_extended.cc
@@ -83,7 +83,7 @@ MXNET_OPERATOR_REGISTER_BINARY(_backward_hypot_scalar)
 .set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Backward<
   cpu, mshadow_op::hypot_grad_left>);
 
-MXNET_OPERATOR_REGISTER_BINARY_SCALAR(smooth_l1)
+NNVM_REGISTER_OP(smooth_l1)
 .describe(R"code(Calculate Smooth L1 Loss(lhs, scalar) by summing
 
 .. math::
@@ -98,17 +98,40 @@ where :math:`x` is an element of the tensor *lhs* and :math:`\sigma` is the scal
 
 Example::
 
+  smooth_l1([1, 2, 3, 4]) = [0.5, 1.5, 2.5, 3.5]
   smooth_l1([1, 2, 3, 4], scalar=1) = [0.5, 1.5, 2.5, 3.5]
 
 )code" ADD_FILELINE)
-.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<
-  cpu, mshadow_op::smooth_l1_loss>)
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_attr_parser([](NodeAttrs* attrs) {
+    if (attrs->dict.find("scalar") != attrs->dict.end()) {
+      attrs->parsed = std::stod(attrs->dict["scalar"]);
+    } else {
+      attrs->parsed = 1.0;
+    }
+  })
+.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
+.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
+.set_attr<nnvm::FInplaceOption>("FInplaceOption",
+  [](const NodeAttrs& attrs){
+    return std::vector<std::pair<int, int> >{{0, 0}};
+  })
+.add_argument("data", "NDArray-or-Symbol", "source input")
+.add_argument("scalar", "float", "scalar input")
+.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, mshadow_op::smooth_l1_loss>)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{ "_backward_smooth_l1" });
 
 MXNET_OPERATOR_REGISTER_BINARY(_backward_smooth_l1)
-.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = std::stod(attrs->dict["scalar"]); })
-.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Backward<
-  cpu, mshadow_op::smooth_l1_gradient>);
+  .set_attr_parser([](NodeAttrs *attrs) {
+    if (attrs->dict.find("scalar") != attrs->dict.end()) {
+      attrs->parsed = std::stod(attrs->dict["scalar"]);
+    } else {
+      attrs->parsed = 1.0;
+    }
+})
+.set_attr<FCompute>("FCompute<cpu>",
+  BinaryScalarOp::Backward<cpu, mshadow_op::smooth_l1_gradient>);
 
 }  // namespace op
 }  // namespace mxnet
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 9842a69e18d4..55a46ca2e93c 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -5956,6 +5956,10 @@ def test_unary_math_operators():
                            lambda x: np_smooth_l1(x, 1.),
                            lambda x: np_smooth_l1_grad(x, 1.),
                            -2.0, 2.0],
+        'smooth_l1_sig_default': [lambda x: mx.sym.smooth_l1(x),
+                                  lambda x: np_smooth_l1(x, 1.),
+                                  lambda x: np_smooth_l1_grad(x, 1.),
+                                  -2.0, 2.0],
         'smooth_l1_sig2': [lambda x: mx.sym.smooth_l1(x, scalar=2.),
                            lambda x: np_smooth_l1(x, 2.),
                            lambda x: np_smooth_l1_grad(x, 2.),
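
A minimal usage sketch (not part of the patch) showing the behavior this change enables: smooth_l1 can now be called without the scalar argument, in which case sigma defaults to 1.0. The imperative mx.nd.smooth_l1 call below is assumed to mirror the symbolic mx.sym.smooth_l1 used in the new test case; the expected values follow the docstring example.

    # Usage sketch, assuming an MXNet build that includes this patch.
    import mxnet as mx

    x = mx.nd.array([1, 2, 3, 4])

    # Before this change 'scalar' was required; it now defaults to 1.0.
    out_default  = mx.nd.smooth_l1(x)             # sigma defaults to 1.0
    out_explicit = mx.nd.smooth_l1(x, scalar=1.)  # equivalent explicit form

    print(out_default.asnumpy())   # expected: [0.5 1.5 2.5 3.5]
    print(out_explicit.asnumpy())  # expected: [0.5 1.5 2.5 3.5]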