Round and sign straight-through-estimators C operators. (apache#16373)
* Implemented round and sign straight-through-estimators C operators.

* fixed lint
igolan authored and sxjscience committed Oct 8, 2019
1 parent ec766d5 commit d5666ed
Showing 4 changed files with 297 additions and 0 deletions.
84 changes: 84 additions & 0 deletions src/operator/contrib/stes_op.cc
@@ -0,0 +1,84 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* Copyright (c) 2019 by Contributors
* \file stes_op.cc
 * \brief Straight-through-estimator round and sign operators.
* \author Itay Golan
*/

#include "stes_op.h"


namespace mxnet {
namespace op {

// Round STE
MXNET_OPERATOR_REGISTER_UNARY_WITH_RSP_CSR(_contrib_round_ste, cpu, mshadow_op::round)
.describe(R"code(Straight-through-estimator of `round()`.
In the forward pass, returns the input rounded element-wise to the nearest integer (same as `round()`).
In the backward pass, returns a gradient of ``1`` everywhere (instead of ``0`` everywhere as in `round()`):
:math:`\frac{d}{dx}{round\_ste(x)} = 1` vs. :math:`\frac{d}{dx}{round(x)} = 0`.
This is useful for quantized training.
Reference: Estimating or Propagating Gradients Through Stochastic Neurons for Conditional Computation.
Example::
x = round_ste([-1.5, 1.5, -1.9, 1.9, 2.7])
x.backward()
x = [-2., 2., -2., 2., 3.]
x.grad() = [1., 1., 1., 1., 1.]
The storage type of ``round_ste`` output depends upon the input storage type:
- round_ste(default) = default
- round_ste(row_sparse) = row_sparse
- round_ste(csr) = csr
)code" ADD_FILELINE)
.set_attr<nnvm::FGradient>("FGradient", CloneGradient{"_backward_round_ste"});

// Sign STE
MXNET_OPERATOR_REGISTER_UNARY_WITH_RSP_CSR(_contrib_sign_ste, cpu, mshadow_op::sign)
.describe(R"code(Straight-through-estimator of `sign()`.
In the forward pass, returns the element-wise sign of the input (same as `sign()`).
In the backward pass, returns a gradient of ``1`` everywhere (instead of ``0`` everywhere as in ``sign()``):
:math:`\frac{d}{dx}{sign\_ste(x)} = 1` vs. :math:`\frac{d}{dx}{sign(x)} = 0`.
This is useful for quantized training.
Reference: Estimating or Propagating Gradients Through Stochastic Neurons for Conditional Computation.
Example::
x = sign_ste([-2, 0, 3])
x.backward()
x = [-1., 0., 1.]
x.grad() = [1., 1., 1.]
The storage type of ``sign_ste`` output depends upon the input storage type:
- sign_ste(default) = default
- sign_ste(row_sparse) = row_sparse
- sign_ste(csr) = csr
)code" ADD_FILELINE)
.set_attr<nnvm::FGradient>("FGradient", CloneGradient{"_backward_sign_ste"});

} // namespace op
} // namespace mxnet
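
A minimal NDArray sketch of the behavior documented above, assuming the operator is exposed under the `contrib` namespace as `mx.nd.contrib.round_ste` (the values mirror the `Example::` block in the docstring):

import mxnet as mx
from mxnet import autograd

x = mx.nd.array([-1.5, 1.5, -1.9, 1.9, 2.7])
x.attach_grad()
with autograd.record():
    y = mx.nd.contrib.round_ste(x)
y.backward()
print(y)       # [-2.  2. -2.  2.  3.] -- forward matches round()
print(x.grad)  # [ 1.  1.  1.  1.  1.] -- gradient is 1 everywhere, not 0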
43 changes: 43 additions & 0 deletions src/operator/contrib/stes_op.cu
@@ -0,0 +1,43 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* Copyright (c) 2019 by Contributors
* \file stes_op.cu
 * \brief Straight-through-estimator round and sign operators.
* \author Itay Golan
*/

#include "stes_op.h"

namespace mxnet {
namespace op {

// Round STE
NNVM_REGISTER_OP(_contrib_round_ste)
.set_attr<FCompute>("FCompute<gpu>", UnaryOp::Compute<gpu, mshadow_op::round>)
.set_attr<FComputeEx>("FComputeEx<gpu>", UnaryOp::ComputeEx<gpu, mshadow_op::round>);

// Sign STE
NNVM_REGISTER_OP(_contrib_sign_ste)
.set_attr<FCompute>("FCompute<gpu>", UnaryOp::Compute<gpu, mshadow_op::sign>)
.set_attr<FComputeEx>("FComputeEx<gpu>", UnaryOp::ComputeEx<gpu, mshadow_op::sign>);

} // namespace op
} // namespace mxnet
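
The same sketch works for the GPU registrations, assuming a CUDA-enabled build and that device 0 is available; only the context changes, since the kernels reuse the same `mshadow_op` functors:

import mxnet as mx
from mxnet import autograd

ctx = mx.gpu(0)
x = mx.nd.array([-2, 0, 3], ctx=ctx)
x.attach_grad()
with autograd.record():
    y = mx.nd.contrib.sign_ste(x)
y.backward()
print(y)       # [-1.  0.  1.]
print(x.grad)  # [ 1.  1.  1.]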
33 changes: 33 additions & 0 deletions src/operator/contrib/stes_op.h
@@ -0,0 +1,33 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* Copyright (c) 2019 by Contributors
* \file stes_op.h
 * \brief Straight-through-estimator round and sign operators.
* \author Itay Golan
*/

#ifndef MXNET_OPERATOR_CONTRIB_STES_OP_H_
#define MXNET_OPERATOR_CONTRIB_STES_OP_H_

#include <mxnet/base.h>
#include "../tensor/elemwise_unary_op.h"

#endif // MXNET_OPERATOR_CONTRIB_STES_OP_H_
137 changes: 137 additions & 0 deletions tests/python/unittest/test_contrib_stes_op.py
@@ -0,0 +1,137 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from common import with_seed
import mxnet as mx
from mxnet import nd, autograd, gluon
from mxnet.test_utils import default_context


class RoundSTENET(gluon.HybridBlock):
def __init__(self, w_init, **kwargs):
super(RoundSTENET, self).__init__(**kwargs)
with self.name_scope():
self.w = self.params.get('w', shape=30, init=mx.initializer.Constant(w_init), grad_req='write')

@staticmethod
def expected_grads(in_data, w_init):
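        # With the STE, round() backpropagates as the identity, so for out = round_ste(w*x)*w:
        # d(out)/dw = round(w*x) + w*x  (elementwise, with w initialized to w_init)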
return (in_data * w_init).round() + (in_data * w_init)

@staticmethod
def expected_output(in_data, w_init):
return (in_data * w_init).round() * w_init

def hybrid_forward(self, F, x, w):
# Simple forward function: round_ste(w*x)*w
out = w * x
out = F.contrib.round_ste(out)
# Uncomment to see how test fails with round
# out = F.round(out)
out = out * w
return out


class SignSTENET(gluon.HybridBlock):
def __init__(self, w_init, **kwargs):
super(SignSTENET, self).__init__(**kwargs)
with self.name_scope():
self.w = self.params.get('w', shape=30, init=mx.initializer.Constant(w_init), grad_req='write')

@staticmethod
def expected_grads(in_data, w_init):
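        # With the STE, sign() backpropagates as the identity, so for out = sign_ste(w*x)*w:
        # d(out)/dw = sign(w*x) + w*x  (elementwise, with w initialized to w_init)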
return (in_data * w_init).sign() + (in_data * w_init)

@staticmethod
def expected_output(in_data, w_init):
return (in_data * w_init).sign() * w_init

def hybrid_forward(self, F, x, w):
# Simple forward function: sign_ste(w*x)*w
out = w * x
out = F.contrib.sign_ste(out)
# Uncomment to see how test fails with sign
# out = F.sign(out)
out = out * w
return out


def check_ste(net_type_str, w_init, hybridize, in_data, ctx=None):
ctx = ctx or default_context()

net = eval(net_type_str)(w_init=w_init)
if hybridize:
net.hybridize()
# Init
net.collect_params().initialize(mx.init.Constant([w_init]), ctx=ctx)

# Test:
in_data = in_data.as_in_context(ctx)
with mx.autograd.record():
out = net(in_data)
assert all(out == net.expected_output(in_data, w_init)), net_type_str + " output is " + str(out) + ", but" + \
" expected " + str(net.expected_output(in_data, w_init))

out.backward()
assert all(net.w.grad() == net.expected_grads(in_data, w_init)), net_type_str + " w grads are " + \
str(net.w.grad()) + " but expected " + \
str(net.expected_grads(in_data, w_init))
with mx.autograd.record():
out = net(in_data)
assert all(out == net.expected_output(in_data, w_init)), net_type_str + " output is " + str(out) + ", but" + \
" expected " + str(net.expected_output(in_data, w_init))
out.backward()
assert all(net.w.grad() == net.expected_grads(in_data, w_init)), net_type_str + " w grads are " + \
str(net.w.grad()) + " but expected " + \
str(net.expected_grads(in_data, w_init))

@with_seed()
def test_contrib_round_ste():
# Test with random data
in_data = nd.uniform(-10, 10, shape=30) # 10 and 30 are arbitrary numbers
w_init = float(nd.uniform(-10, 10, shape=1).asscalar())
check_ste(net_type_str="RoundSTENET", w_init=w_init, hybridize=True, in_data=in_data)
check_ste(net_type_str="RoundSTENET", w_init=w_init, hybridize=False, in_data=in_data)

# Test 1.5 (verifies that .5 rounds the same as in round)
    in_data = nd.array([1.5]*30)  # 30 is an arbitrary length
w_init = 1.
check_ste(net_type_str="RoundSTENET", w_init=w_init, hybridize=True, in_data=in_data)
check_ste(net_type_str="RoundSTENET", w_init=w_init, hybridize=False, in_data=in_data)

# Test 0
    in_data = nd.array([0]*30)  # 30 is an arbitrary length
w_init = 0.
check_ste(net_type_str="RoundSTENET", w_init=w_init, hybridize=True, in_data=in_data)
check_ste(net_type_str="RoundSTENET", w_init=w_init, hybridize=False, in_data=in_data)


@with_seed()
def test_contrib_sign_ste():
in_data = nd.uniform(-10, 10, shape=30) # 10 and 30 are arbitrary numbers
w_init = float(nd.uniform(-10, 10, shape=1).asscalar())
check_ste(net_type_str="SignSTENET", w_init=w_init, hybridize=True, in_data=in_data)
check_ste(net_type_str="SignSTENET", w_init=w_init, hybridize=False, in_data=in_data)

# Test 0
    in_data = nd.array([0]*30)  # 30 is an arbitrary length
w_init = 0.
check_ste(net_type_str="SignSTENET", w_init=w_init, hybridize=True, in_data=in_data)
check_ste(net_type_str="SignSTENET", w_init=w_init, hybridize=False, in_data=in_data)

if __name__ == '__main__':
import nose
nose.runmodule()
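
The tests above cover dense NDArrays; a quick sketch of the sparse path registered through `MXNET_OPERATOR_REGISTER_UNARY_WITH_RSP_CSR` (assuming the default context and `tostype` conversions) could look like:

import mxnet as mx

xs = mx.nd.array([[-1.5, 0.0, 2.7]]).tostype('row_sparse')
ys = mx.nd.contrib.round_ste(xs)
assert ys.stype == 'row_sparse'  # storage type is preserved, per the docstring
assert (ys.tostype('default') == mx.nd.array([[-2., 0., 3.]])).asnumpy().all()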
