forked from apache/mxnet
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Round and sign straight-through-estimators C operators. (apache#16373)
* Implemented round and sign straight-through-estimators C operators. * fuxed lint
- Loading branch information
1 parent
ec766d5
commit d5666ed
Showing
4 changed files
with
297 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, | ||
* software distributed under the License is distributed on an | ||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
* KIND, either express or implied. See the License for the | ||
* specific language governing permissions and limitations | ||
* under the License. | ||
*/ | ||
|
||
/*! | ||
* Copyright (c) 2019 by Contributors | ||
* \file stes_op.cc | ||
* \Straight-through-estimators round and sign operators. | ||
* \author Itay Golan | ||
*/ | ||
|
||
#include "stes_op.h" | ||
|
||
|
||
namespace mxnet { | ||
namespace op { | ||
|
||
// Round STE | ||
MXNET_OPERATOR_REGISTER_UNARY_WITH_RSP_CSR(_contrib_round_ste, cpu, mshadow_op::round) | ||
.describe(R"code(Straight-through-estimator of `round()`. | ||
In forward pass, returns element-wise rounded value to the nearest integer of the input (same as `round()`). | ||
In backward pass, returns gradients of ``1`` everywhere (instead of ``0`` everywhere as in `round()`): | ||
:math:`\frac{d}{dx}{round\_ste(x)} = 1` vs. :math:`\frac{d}{dx}{round(x)} = 0`. | ||
This is useful for quantized training. | ||
Reference: Estimating or Propagating Gradients Through Stochastic Neurons for Conditional Computation. | ||
Example:: | ||
x = round_ste([-1.5, 1.5, -1.9, 1.9, 2.7]) | ||
x.backward() | ||
x = [-2., 2., -2., 2., 3.] | ||
x.grad() = [1., 1., 1., 1., 1.] | ||
The storage type of ``round_ste`` output depends upon the input storage type: | ||
- round_ste(default) = default | ||
- round_ste(row_sparse) = row_sparse | ||
- round_ste(csr) = csr | ||
)code" ADD_FILELINE) | ||
.set_attr<nnvm::FGradient>("FGradient", CloneGradient{"_backward_round_ste"}); | ||
|
||
// sign | ||
MXNET_OPERATOR_REGISTER_UNARY_WITH_RSP_CSR(_contrib_sign_ste, cpu, mshadow_op::sign) | ||
.describe(R"code(Straight-through-estimator of `sign()`. | ||
In forward pass, returns element-wise sign of the input (same as `sign()`). | ||
In backward pass, returns gradients of ``1`` everywhere (instead of ``0`` everywhere as in ``sign()``): | ||
:math:`\frac{d}{dx}{sign\_ste(x)} = 1` vs. :math:`\frac{d}{dx}{sign(x)} = 0`. | ||
This is useful for quantized training. | ||
Reference: Estimating or Propagating Gradients Through Stochastic Neurons for Conditional Computation. | ||
Example:: | ||
x = sign_ste([-2, 0, 3]) | ||
x.backward() | ||
x = [-1., 0., 1.] | ||
x.grad() = [1., 1., 1.] | ||
The storage type of ``sign_ste`` output depends upon the input storage type: | ||
- round_ste(default) = default | ||
- round_ste(row_sparse) = row_sparse | ||
- round_ste(csr) = csr | ||
)code" ADD_FILELINE) | ||
.set_attr<nnvm::FGradient>("FGradient", CloneGradient{"_backward_sign_ste"}); | ||
|
||
} // namespace op | ||
} // namespace mxnet |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, | ||
* software distributed under the License is distributed on an | ||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
* KIND, either express or implied. See the License for the | ||
* specific language governing permissions and limitations | ||
* under the License. | ||
*/ | ||
|
||
/*! | ||
* Copyright (c) 2019 by Contributors | ||
* \file stes_op.cu | ||
* \Straight-through-estimators round and sign operators. | ||
* \author Itay Golan | ||
*/ | ||
|
||
#include "stes_op.h" | ||
|
||
namespace mxnet { | ||
namespace op { | ||
|
||
// Round STE | ||
NNVM_REGISTER_OP(_contrib_round_ste) | ||
.set_attr<FCompute>("FCompute<gpu>", UnaryOp::Compute<gpu, mshadow_op::round>) | ||
.set_attr<FComputeEx>("FComputeEx<gpu>", UnaryOp::ComputeEx<gpu, mshadow_op::round>); | ||
|
||
// Sign STE | ||
NNVM_REGISTER_OP(_contrib_sign_ste) | ||
.set_attr<FCompute>("FCompute<gpu>", UnaryOp::Compute<gpu, mshadow_op::sign>) | ||
.set_attr<FComputeEx>("FComputeEx<gpu>", UnaryOp::ComputeEx<gpu, mshadow_op::sign>); | ||
|
||
} // namespace op | ||
} // namespace mxnet |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, | ||
* software distributed under the License is distributed on an | ||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
* KIND, either express or implied. See the License for the | ||
* specific language governing permissions and limitations | ||
* under the License. | ||
*/ | ||
|
||
/*! | ||
* Copyright (c) 2019 by Contributors | ||
* \file stes_op.h | ||
* \Straight-through-estimators round and sign operators. | ||
* \author Itay Golan | ||
*/ | ||
|
||
#ifndef MXNET_OPERATOR_CONTRIB_STES_OP_H_ | ||
#define MXNET_OPERATOR_CONTRIB_STES_OP_H_ | ||
|
||
#include <mxnet/base.h> | ||
#include "../tensor/elemwise_unary_op.h" | ||
|
||
#endif // MXNET_OPERATOR_CONTRIB_STES_OP_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,137 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
|
||
from common import with_seed | ||
import mxnet as mx | ||
from mxnet import nd, autograd, gluon | ||
from mxnet.test_utils import default_context | ||
|
||
|
||
class RoundSTENET(gluon.HybridBlock): | ||
def __init__(self, w_init, **kwargs): | ||
super(RoundSTENET, self).__init__(**kwargs) | ||
with self.name_scope(): | ||
self.w = self.params.get('w', shape=30, init=mx.initializer.Constant(w_init), grad_req='write') | ||
|
||
@staticmethod | ||
def expected_grads(in_data, w_init): | ||
return (in_data * w_init).round() + (in_data * w_init) | ||
|
||
@staticmethod | ||
def expected_output(in_data, w_init): | ||
return (in_data * w_init).round() * w_init | ||
|
||
def hybrid_forward(self, F, x, w): | ||
# Simple forward function: round_ste(w*x)*w | ||
out = w * x | ||
out = F.contrib.round_ste(out) | ||
# Uncomment to see how test fails with round | ||
# out = F.round(out) | ||
out = out * w | ||
return out | ||
|
||
|
||
class SignSTENET(gluon.HybridBlock): | ||
def __init__(self, w_init, **kwargs): | ||
super(SignSTENET, self).__init__(**kwargs) | ||
with self.name_scope(): | ||
self.w = self.params.get('w', shape=30, init=mx.initializer.Constant(w_init), grad_req='write') | ||
|
||
@staticmethod | ||
def expected_grads(in_data, w_init): | ||
return (in_data * w_init).sign() + (in_data * w_init) | ||
|
||
@staticmethod | ||
def expected_output(in_data, w_init): | ||
return (in_data * w_init).sign() * w_init | ||
|
||
def hybrid_forward(self, F, x, w): | ||
# Simple forward function: sign_ste(w*x)*w | ||
out = w * x | ||
out = F.contrib.sign_ste(out) | ||
# Uncomment to see how test fails with sign | ||
# out = F.sign(out) | ||
out = out * w | ||
return out | ||
|
||
|
||
def check_ste(net_type_str, w_init, hybridize, in_data, ctx=None): | ||
ctx = ctx or default_context() | ||
|
||
net = eval(net_type_str)(w_init=w_init) | ||
if hybridize: | ||
net.hybridize() | ||
# Init | ||
net.collect_params().initialize(mx.init.Constant([w_init]), ctx=ctx) | ||
|
||
# Test: | ||
in_data = in_data.as_in_context(ctx) | ||
with mx.autograd.record(): | ||
out = net(in_data) | ||
assert all(out == net.expected_output(in_data, w_init)), net_type_str + " output is " + str(out) + ", but" + \ | ||
" expected " + str(net.expected_output(in_data, w_init)) | ||
|
||
out.backward() | ||
assert all(net.w.grad() == net.expected_grads(in_data, w_init)), net_type_str + " w grads are " + \ | ||
str(net.w.grad()) + " but expected " + \ | ||
str(net.expected_grads(in_data, w_init)) | ||
with mx.autograd.record(): | ||
out = net(in_data) | ||
assert all(out == net.expected_output(in_data, w_init)), net_type_str + " output is " + str(out) + ", but" + \ | ||
" expected " + str(net.expected_output(in_data, w_init)) | ||
out.backward() | ||
assert all(net.w.grad() == net.expected_grads(in_data, w_init)), net_type_str + " w grads are " + \ | ||
str(net.w.grad()) + " but expected " + \ | ||
str(net.expected_grads(in_data, w_init)) | ||
|
||
@with_seed() | ||
def test_contrib_round_ste(): | ||
# Test with random data | ||
in_data = nd.uniform(-10, 10, shape=30) # 10 and 30 are arbitrary numbers | ||
w_init = float(nd.uniform(-10, 10, shape=1).asscalar()) | ||
check_ste(net_type_str="RoundSTENET", w_init=w_init, hybridize=True, in_data=in_data) | ||
check_ste(net_type_str="RoundSTENET", w_init=w_init, hybridize=False, in_data=in_data) | ||
|
||
# Test 1.5 (verifies that .5 rounds the same as in round) | ||
in_data = nd.array([1.5]*30) # 10 and 30 are arbitrary numbers | ||
w_init = 1. | ||
check_ste(net_type_str="RoundSTENET", w_init=w_init, hybridize=True, in_data=in_data) | ||
check_ste(net_type_str="RoundSTENET", w_init=w_init, hybridize=False, in_data=in_data) | ||
|
||
# Test 0 | ||
in_data = nd.array([0]*30) # 10 and 30 are arbitrary numbers | ||
w_init = 0. | ||
check_ste(net_type_str="RoundSTENET", w_init=w_init, hybridize=True, in_data=in_data) | ||
check_ste(net_type_str="RoundSTENET", w_init=w_init, hybridize=False, in_data=in_data) | ||
|
||
|
||
@with_seed() | ||
def test_contrib_sign_ste(): | ||
in_data = nd.uniform(-10, 10, shape=30) # 10 and 30 are arbitrary numbers | ||
w_init = float(nd.uniform(-10, 10, shape=1).asscalar()) | ||
check_ste(net_type_str="SignSTENET", w_init=w_init, hybridize=True, in_data=in_data) | ||
check_ste(net_type_str="SignSTENET", w_init=w_init, hybridize=False, in_data=in_data) | ||
|
||
# Test 0 | ||
in_data = nd.array([0]*30) # 10 and 30 are arbitrary numbers | ||
w_init = 0. | ||
check_ste(net_type_str="SignSTENET", w_init=w_init, hybridize=True, in_data=in_data) | ||
check_ste(net_type_str="SignSTENET", w_init=w_init, hybridize=False, in_data=in_data) | ||
|
||
if __name__ == '__main__': | ||
import nose | ||
nose.runmodule() |