Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Aggregated zero grad (#16446)
Browse files Browse the repository at this point in the history
* Trigger CI

* Aggregated zeroing of the gradients/arrays

* New files for aggregated zeroing of the gradients/arrays

* Adding possibility to reset the arrays of different types.

* Minor cleanup
  • Loading branch information
drivanov authored and apeforest committed Nov 6, 2019
1 parent 8c22fac commit 82ed82f
Show file tree
Hide file tree
Showing 5 changed files with 281 additions and 12 deletions.
21 changes: 18 additions & 3 deletions python/mxnet/gluon/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,10 @@
'ParameterDict', 'tensor_types']


from collections import OrderedDict
from collections import OrderedDict, defaultdict
import warnings
import numpy as np
import mxnet as mx

from ..base import mx_real_t, MXNetError
from .. import symbol, ndarray, initializer, context
Expand Down Expand Up @@ -887,8 +888,22 @@ def initialize(self, init=initializer.Uniform(), ctx=None, verbose=False,

def zero_grad(self):
"""Sets all Parameters' gradient buffer to 0."""
for i in self.values():
i.zero_grad()
# collect gradient arrays for each ctx
arrays = defaultdict(list)
for p in self.values():
if p.grad_req == 'null' or p._grad is None:
continue
for g in p.list_grad():
if g.stype == 'row_sparse':
mx.ndarray.zeros_like(g, out=g)
else:
arrays[g.context].append(g)

if len(arrays) == 0:
return

for arr in arrays.values():
mx.nd.reset_arrays(*arr, num_arrays=len(arr))

def reset_ctx(self, ctx):
"""Re-assign all Parameters to other contexts.
Expand Down
92 changes: 92 additions & 0 deletions src/operator/contrib/reset_arrays-inl.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* Copyright (c) 2019 by Contributors
* \file reset_arrays-inl.h
* \brief setting all array element values to zeros
* \author Moises Hernandez-Fernandez, Andrei Ivanov
*/

#ifndef MXNET_OPERATOR_CONTRIB_RESET_ARRAYS_INL_H_
#define MXNET_OPERATOR_CONTRIB_RESET_ARRAYS_INL_H_

#include <vector>
#include "../tensor/init_op.h"

namespace mxnet {
namespace op {

struct ResetArraysParam : public dmlc::Parameter<ResetArraysParam> {
int num_arrays;
DMLC_DECLARE_PARAMETER(ResetArraysParam) {
DMLC_DECLARE_FIELD(num_arrays)
.describe("number of input arrays.");
}
};

inline bool ResetArraysShape(const NodeAttrs& attrs,
std::vector<mxnet::TShape>* in_shape,
std::vector<mxnet::TShape>* out_shape) {
const auto& param = dmlc::get<ResetArraysParam>(attrs.parsed);
CHECK_EQ(in_shape->size(), param.num_arrays);
for (auto s : *in_shape) {
if (s.ndim() == 0)
return false;
}

return true;
}

inline bool ResetArraysType(const NodeAttrs& attrs,
std::vector<int>* in_type,
std::vector<int>* out_type) {
const auto& param = dmlc::get<ResetArraysParam>(attrs.parsed);
CHECK_EQ(in_type->size(), param.num_arrays);
for (size_t i = 0; i < in_type->size(); ++i) {
if ((*in_type)[i] == -1)
return false;
}

return true;
}

template<typename xpu>
void ResetMemory(void *pntr, size_t len, mshadow::Stream<xpu> *s);

template<typename xpu>
void ResetArrays(const nnvm::NodeAttrs& attrs,
const OpContext &ctx,
const std::vector<TBlob> &inputs,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &outputs) {
auto s = ctx.get_stream<xpu>();
const auto& param = nnvm::get<ResetArraysParam>(attrs.parsed);
for (int i = 0; i < param.num_arrays; i++) { // array index in inputs
const size_t size = inputs[i].shape_.Size();
MSHADOW_REAL_TYPE_SWITCH(inputs[i].type_flag_, DType,
ResetMemory(inputs[i].FlatTo2D<xpu, DType>(s).dptr_, size * sizeof(DType), s);
)
}
}

} // namespace op
} // namespace mxnet

#endif // MXNET_OPERATOR_CONTRIB_RESET_ARRAYS_INL_H_
74 changes: 74 additions & 0 deletions src/operator/contrib/reset_arrays.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* Copyright (c) 2019 by Contributors
* \file reset_arrays.cc
* \brief setting all array element values to zeros
* \author Moises Hernandez-Fernandez, Andrei Ivanov
*/

#include "./reset_arrays-inl.h"

namespace mxnet {
namespace op {

DMLC_REGISTER_PARAMETER(ResetArraysParam);

NNVM_REGISTER_OP(reset_arrays)
.describe(R"code(Set to zero multiple arrays
)code" ADD_FILELINE)
.set_num_inputs([](const nnvm::NodeAttrs& attrs) {
return static_cast<uint32_t>(dmlc::get<ResetArraysParam>(attrs.parsed).num_arrays);
})
.set_attr<nnvm::FMutateInputs>("FMutateInputs",
[](const nnvm::NodeAttrs& attrs) {
const uint32_t num_args = dmlc::get<ResetArraysParam>(attrs.parsed).num_arrays;
std::vector<uint32_t> ret;
for (uint32_t i = 0; i < num_args; ++i) {
ret.push_back(i);
}
return ret;
})
.set_num_outputs(0)
.set_attr_parser(ParamParser<ResetArraysParam>)
.set_attr<mxnet::FInferShape>("FInferShape", ResetArraysShape)
.set_attr<nnvm::FInferType>("FInferType", ResetArraysType)
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
const uint32_t num_args = dmlc::get<ResetArraysParam>(attrs.parsed).num_arrays;
std::vector<std::string> ret;
for (uint32_t i = 0; i < num_args; ++i) {
ret.push_back(std::string("array_") + std::to_string(i));
}
return ret;
})
.add_argument("data", "NDArray-or-Symbol[]", "Arrays")
.add_arguments(ResetArraysParam::__FIELDS__());

NNVM_REGISTER_OP(reset_arrays)
.set_attr<FCompute>("FCompute<cpu>", ResetArrays<cpu>);

template<>
void ResetMemory<cpu>(void *pntr, size_t len, mshadow::Stream<cpu> *s) {
memset(pntr, 0, len);
}

} // namespace op
} // namespace mxnet
40 changes: 40 additions & 0 deletions src/operator/contrib/reset_arrays.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* Copyright (c) 2019 by Contributors
* \file reset_arrays.cu
* \brief setting all array element values to zeros
* \author Moises Hernandez-Fernandez, Andrei Ivanov
*/
#include "./reset_arrays-inl.h"

namespace mxnet {
namespace op {

template<>
void ResetMemory<gpu>(void *pntr, size_t len, mshadow::Stream<gpu> *s) {
CUDA_CALL(cudaMemsetAsync(pntr, 0, len, mshadow::Stream<gpu>::GetStream(s)));
}

NNVM_REGISTER_OP(reset_arrays)
.set_attr<FCompute>("FCompute<gpu>", ResetArrays<gpu>);

} // namespace op
} // namespace mxnet
66 changes: 57 additions & 9 deletions tests/python/unittest/test_gluon.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import warnings
import json
import unittest
import random

@with_seed()
def test_parameter():
Expand Down Expand Up @@ -1504,15 +1505,62 @@ def test_hybrid_multi_context():

@with_seed()
def test_zero_grad():
data = mx.nd.random.uniform(shape=(3,3))
net = nn.Embedding(3, 4, sparse_grad=True, prefix='test_zero_grad_')
net.initialize()
with mx.autograd.record():
l = net(data)
l.backward()
net.collect_params().zero_grad()
grad = net.collect_params()['test_zero_grad_weight'].grad()
assert_almost_equal(grad.asnumpy(), grad.asnumpy() * 0)
def _test_grad_reset(ctx, dtype='float32', sparse=False, embeddingType=None):
data = mx.nd.random.uniform(shape=(3,3), dtype=dtype, ctx=ctx)
if embeddingType is None:
embeddingType = dtype
net = nn.Embedding(3, 4, sparse_grad=sparse, prefix='test_zero_grad_', dtype=embeddingType)
net.initialize(ctx=ctx)
with mx.autograd.record():
l = net(data)
l.backward()
net.collect_params().zero_grad()
grad = net.collect_params()['test_zero_grad_weight'].grad()
assert_almost_equal(grad.asnumpy(), grad.asnumpy() * 0)

def _test_multi_reset(nArrays, dtype, ctx):
# Construct the list of non-zeros arrays with random shapes
arr = []
for _ in range(nArrays):
arrType = random.choice(dtype) if isinstance(dtype, list) else dtype
shape = ()
for _ in range(np.random.randint(1, 5)):
shape = shape + (np.random.randint(1, 10),)
arr.append(mx.nd.random.uniform(shape=shape, dtype=arrType, ctx=ctx))

# Reset all arrays
mx.nd.reset_arrays(*arr, num_arrays=len(arr))

# Check results
for i in range(nArrays):
grad = arr[i].asnumpy()
assert_almost_equal(grad, grad * 0)


# Setting context for current test
ctx = mx.context.current_context()

# Launching _test_multi_reset 10 times with different types & randomly chosen nArrays
testedTypes = ['float16', 'float32', 'float64']
for _ in range(10):
for type in [testedTypes] + testedTypes:
_test_multi_reset(np.random.randint(1, 50), type, ctx)

# Saving value of environment variable, if it was defined
envVarKey = 'MXNET_STORAGE_FALLBACK_LOG_VERBOSE'
envVarValue = os.environ[envVarKey] if envVarKey in os.environ else None
# Changing value of environment variable
os.environ[envVarKey] = '0'
for type in ['float16', 'float32', 'float64']:
for embType in ['float32', 'float64']:
for sparse in [True, False]:
_test_grad_reset(ctx, dtype=type, sparse=sparse, embeddingType=embType)

# Remove or restore the value of environment variable
if envVarValue is None:
del os.environ[envVarKey]
else:
os.environ[envVarKey] = envVarValue

def check_hybrid_static_memory(**kwargs):
x = mx.nd.random.uniform(shape=(2, 3, 32, 32))
Expand Down

0 comments on commit 82ed82f

Please sign in to comment.