implement det sign
Jopyth committed Sep 5, 2018
1 parent b3be92f commit ada4ea1
Showing 7 changed files with 78 additions and 0 deletions.
8 changes: 8 additions & 0 deletions python/mxnet/ndarray/ndarray.py
@@ -1246,6 +1246,14 @@ def sign(self, *args, **kwargs):
"""
return op.sign(self, *args, **kwargs)

def det_sign(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`det_sign`.
The arguments are the same as for :py:func:`det_sign`, with
this array as data.
"""
return op.det_sign(self, *args, **kwargs)

def flatten(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`flatten`.
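For context, a minimal sketch of the fluent call this hunk enables (assuming a build of this fork with det_sign registered; registered operators are also exposed as mx.nd.det_sign):

    import mxnet as mx

    x = mx.nd.array([-2.0, 0.0, 3.0])
    y = x.det_sign()      # fluent form, equivalent to mx.nd.det_sign(x)
    print(y.asnumpy())    # expected: [-1.  1.  1.]  (zero maps to +1)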
8 changes: 8 additions & 0 deletions python/mxnet/symbol/symbol.py
@@ -1974,6 +1974,14 @@ def sign(self, *args, **kwargs):
"""
return op.sign(self, *args, **kwargs)

def det_sign(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`det_sign`.
The arguments are the same as for :py:func:`det_sign`, with
this array as data.
"""
return op.det_sign(self, *args, **kwargs)

def flatten(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`flatten`.
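The symbolic counterpart can be exercised the same way; a sketch under the same assumption (Symbol.eval binds and runs the graph in one step):

    import mxnet as mx

    data = mx.sym.Variable('data')
    out = data.det_sign()     # fluent form, equivalent to mx.sym.det_sign(data)
    res = out.eval(ctx=mx.cpu(), data=mx.nd.array([-2.0, 0.0, 3.0]))
    print(res[0].asnumpy())   # expected: [-1.  1.  1.]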
17 changes: 17 additions & 0 deletions smd_hpi/tests/test_functions.py
@@ -0,0 +1,17 @@
import mxnet as mx
import numpy as np
from mxnet import autograd
from mxnet.test_utils import assert_almost_equal


def test_det_sign():
exp_y = np.array([1.0, 1.0, -1.0])
exp_grad = np.array([1.0, 1.0, 1.0])

x = mx.nd.array([0.0, 0.6, -0.3])
x.attach_grad()
with autograd.record():
y = x.det_sign()
assert_almost_equal(exp_y, y.asnumpy())
y.backward()
assert_almost_equal(exp_grad, x.grad.asnumpy())
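The test pins down the operator's contract: det_sign(0) = +1 in the forward pass, and the gradient passes through unchanged in the backward pass. A plain NumPy restatement, illustrative only and not part of the commit:

    import numpy as np

    def det_sign_forward(x):
        # sign, with the zero case mapped to +1
        return np.where(x < 0, -1.0, 1.0)

    def det_sign_backward(grad_out):
        # Straight-Through Estimator: the incoming gradient is passed through unchanged
        return grad_out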
15 changes: 15 additions & 0 deletions src/operator/mshadow_op.h
@@ -299,6 +299,21 @@ struct sign : public mxnet_op::tunable {

MXNET_UNARY_MATH_OP_NC(sign_grad, DType(0));

/*! \brief used for generate element of det_sign */
struct det_sign : public mxnet_op::tunable {
template<typename DType>
MSHADOW_XINLINE static typename enable_if<!is_unsigned<DType>::value, DType>::type
Map(DType a) {
if (a < DType(0)) return DType(-DType(1));
return DType(1);
}
template<typename DType>
MSHADOW_XINLINE static typename enable_if<is_unsigned<DType>::value, DType>::type
Map(DType a) {
return DType(1);
}
};
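The two enable_if overloads split signed from unsigned element types: for an unsigned DType the comparison a < DType(0) can never hold (and would trigger a compiler warning about a tautological comparison), so the unsigned overload returns 1 unconditionally.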

/*! \brief used for generate element of power */
MXNET_BINARY_MATH_OP(power, math::pow(a, b));

2 changes: 2 additions & 0 deletions src/operator/operator_tune.cc
@@ -270,6 +270,8 @@ IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::abs); // NOLINT()
IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::sign); // NOLINT()
IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::sign); // NOLINT()
IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::sign_grad); // NOLINT()
IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::det_sign); // NOLINT()
IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::det_sign); // NOLINT()
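As I understand MXNet's tuning machinery, these IMPLEMENT_UNARY_WORKLOAD_FWD/BWD lines register the new functor with the operator auto-tuner, which benchmarks each kernel to decide when multi-threaded execution pays off; without them the kernel would be dispatched with untuned defaults.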
IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::round); // NOLINT()
IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::floor); // NOLINT()
IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::trunc); // NOLINT()
19 changes: 19 additions & 0 deletions src/operator/tensor/elemwise_unary_op_basic.cc
@@ -687,6 +687,25 @@ The storage type of ``sign`` output depends upon the input storage type:

MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU(_backward_sign, unary_bwd<mshadow_op::sign_grad>);

// det_sign
MXNET_OPERATOR_REGISTER_UNARY_WITH_RSP(det_sign, cpu, mshadow_op::det_sign)
MXNET_ADD_SPARSE_OP_ALIAS(det_sign)
.describe(R"code(Returns element-wise sign of the input (but with +1 for 0 values and Straigth Through Estimator).
Example::
det_sign([-2, 0, 3]) = [-1, 1, 1]
The storage type of ``det_sign`` output depends upon the input storage type:
- det_sign(default) = default
- det_sign(row_sparse) = row_sparse
)code" ADD_FILELINE)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_det_sign"});

MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU(_backward_det_sign, unary_bwd<mshadow_op::identity_grad>);
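Note that _backward_det_sign is wired to identity_grad: the backward pass treats the derivative as 1 everywhere, which is exactly the Straight-Through Estimator, whereas the stock sign operator (registered just above with sign_grad) has zero gradient almost everywhere. A short sketch of the difference, assuming this fork's build:

    import mxnet as mx
    from mxnet import autograd

    x = mx.nd.array([-2.0, 0.0, 3.0])
    x.attach_grad()
    with autograd.record():
        y = x.det_sign()
    y.backward()
    print(x.grad.asnumpy())   # [1. 1. 1.] -- straight-through gradient

    with autograd.record():
        z = x.sign()
    z.backward()
    print(x.grad.asnumpy())   # [0. 0. 0.] -- plain sign blocks the gradient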

// round
MXNET_OPERATOR_REGISTER_UNARY_WITH_RSP_CSR(round, cpu, mshadow_op::round)
.describe(R"code(Returns element-wise rounded value to the nearest integer of the input.
9 changes: 9 additions & 0 deletions src/operator/tensor/elemwise_unary_op_basic.cu
@@ -159,6 +159,15 @@ NNVM_REGISTER_OP(_backward_sign)
.set_attr<FCompute>("FCompute<gpu>", ElemwiseBinaryOp::Compute<
gpu, unary_bwd<mshadow_op::sign_grad> >);

// det_sign
NNVM_REGISTER_OP(det_sign)
.set_attr<FCompute>("FCompute<gpu>", UnaryOp::Compute<gpu, mshadow_op::det_sign>)
.set_attr<FComputeEx>("FComputeEx<gpu>", UnaryOp::ComputeEx<gpu, mshadow_op::det_sign>);

NNVM_REGISTER_OP(_backward_det_sign)
.set_attr<FCompute>("FCompute<gpu>", ElemwiseBinaryOp::Compute<
gpu, unary_bwd<mshadow_op::identity_grad> >);

// round
NNVM_REGISTER_OP(round)
.set_attr<FCompute>("FCompute<gpu>", UnaryOp::Compute<gpu, mshadow_op::round>)
