diff --git a/python/mxnet/gluon/parameter.py b/python/mxnet/gluon/parameter.py
index dedd81496c2f..8800684ad0b4 100644
--- a/python/mxnet/gluon/parameter.py
+++ b/python/mxnet/gluon/parameter.py
@@ -24,9 +24,9 @@
            'ParameterDict', 'tensor_types']
 
 
-from collections import OrderedDict
+from collections import OrderedDict, defaultdict
 import warnings
 import numpy as np
 
 from ..base import mx_real_t, MXNetError
 from .. import symbol, ndarray, initializer, context
@@ -887,8 +887,21 @@ def initialize(self, init=initializer.Uniform(), ctx=None, verbose=False,
 
     def zero_grad(self):
         """Sets all Parameters' gradient buffer to 0."""
-        for i in self.values():
-            i.zero_grad()
+        # Group dense gradient buffers by context so that each context is
+        # cleared with a single reset_arrays call.
+        arrays = defaultdict(list)
+        for p in self.values():
+            if p.grad_req == 'null' or p._grad is None:
+                continue
+            for g in p.list_grad():
+                if g.stype == 'row_sparse':
+                    # row_sparse gradients are zeroed one by one
+                    ndarray.zeros_like(g, out=g)
+                else:
+                    arrays[g.context].append(g)
+
+        for arr in arrays.values():
+            ndarray.reset_arrays(*arr, num_arrays=len(arr))
 
     def reset_ctx(self, ctx):
         """Re-assign all Parameters to other contexts.
diff --git a/src/operator/contrib/reset_arrays-inl.h b/src/operator/contrib/reset_arrays-inl.h
new file mode 100644
index 000000000000..3559a21879e4
--- /dev/null
+++ b/src/operator/contrib/reset_arrays-inl.h
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *  Copyright (c) 2019 by Contributors
+ * \file reset_arrays-inl.h
+ * \brief setting all array element values to zeros
+ * \author Moises Hernandez-Fernandez, Andrei Ivanov
+ */
+
+#ifndef MXNET_OPERATOR_CONTRIB_RESET_ARRAYS_INL_H_
+#define MXNET_OPERATOR_CONTRIB_RESET_ARRAYS_INL_H_
+
+#include <vector>
+#include "../tensor/init_op.h"
+
+namespace mxnet {
+namespace op {
+
+/*! \brief Parameters of the reset_arrays operator. */
+struct ResetArraysParam : public dmlc::Parameter<ResetArraysParam> {
+  int num_arrays;
+  DMLC_DECLARE_PARAMETER(ResetArraysParam) {
+    DMLC_DECLARE_FIELD(num_arrays)
+    .describe("number of input arrays.");
+  }
+};
+
+/*!
+ * \brief Shape check: expects exactly num_arrays input shapes and returns
+ *        false while any of them still has ndim() == 0.
+ */
+inline bool ResetArraysShape(const NodeAttrs& attrs,
+                             std::vector<mxnet::TShape>* in_shape,
+                             std::vector<mxnet::TShape>* out_shape) {
+  const auto& param = dmlc::get<ResetArraysParam>(attrs.parsed);
+  CHECK_EQ(in_shape->size(), param.num_arrays);
+  for (const auto& s : *in_shape) {
+    if (s.ndim() == 0)
+      return false;
+  }
+
+  return true;
+}
+
+/*!
+ * \brief Type check: expects exactly num_arrays input types and returns
+ *        false while any of them is still -1.
+ */
+inline bool ResetArraysType(const NodeAttrs& attrs,
+                            std::vector<int>* in_type,
+                            std::vector<int>* out_type) {
+  const auto& param = dmlc::get<ResetArraysParam>(attrs.parsed);
+  CHECK_EQ(in_type->size(), param.num_arrays);
+  for (size_t i = 0; i < in_type->size(); ++i) {
+    if ((*in_type)[i] == -1)
+      return false;
+  }
+
+  return true;
+}
+
+/*!
+ * \brief Zero `len` bytes starting at `pntr`; specialized per device in
+ *        reset_arrays.cc and reset_arrays.cu.
+ */
+template<typename xpu>
+void ResetMemory(void *pntr, size_t len, mshadow::Stream<xpu> *s);
+
+/*!
+ * \brief Zero the memory of the first num_arrays input blobs in place.
+ *        `req` and `outputs` are not used by this function.
+ */
+template<typename xpu>
+void ResetArrays(const nnvm::NodeAttrs& attrs,
+                 const OpContext &ctx,
+                 const std::vector<TBlob> &inputs,
+                 const std::vector<OpReqType> &req,
+                 const std::vector<TBlob> &outputs) {
+  auto s = ctx.get_stream<xpu>();
+  const auto& param = nnvm::get<ResetArraysParam>(attrs.parsed);
+  for (int i = 0; i < param.num_arrays; i++) {  // array index in inputs
+    const size_t size = inputs[i].shape_.Size();
+    MSHADOW_REAL_TYPE_SWITCH(inputs[i].type_flag_, DType,
+      ResetMemory(inputs[i].FlatTo2D<xpu, DType>(s).dptr_, size * sizeof(DType), s);
+    )
+  }
+}
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_OPERATOR_CONTRIB_RESET_ARRAYS_INL_H_
diff --git a/src/operator/contrib/reset_arrays.cc b/src/operator/contrib/reset_arrays.cc
new file mode 100644
index 000000000000..f67e0098cd11
--- /dev/null
+++ b/src/operator/contrib/reset_arrays.cc
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *  Copyright (c) 2019 by Contributors
+ * \file reset_arrays.cc
+ * \brief setting all array element values to zeros
+ * \author Moises Hernandez-Fernandez, Andrei Ivanov
+ */
+
+#include "./reset_arrays-inl.h"
+
+namespace mxnet {
+namespace op {
+
+DMLC_REGISTER_PARAMETER(ResetArraysParam);
+
+NNVM_REGISTER_OP(reset_arrays)
+.describe(R"code(Set to zero multiple arrays
+)code" ADD_FILELINE)
+.set_num_inputs([](const nnvm::NodeAttrs& attrs) {
+    return static_cast<uint32_t>(dmlc::get<ResetArraysParam>(attrs.parsed).num_arrays);
+  })
+// every one of the num_arrays inputs is listed as mutated
+.set_attr<nnvm::FMutateInputs>("FMutateInputs",
+  [](const nnvm::NodeAttrs& attrs) {
+    const uint32_t num_args = dmlc::get<ResetArraysParam>(attrs.parsed).num_arrays;
+    std::vector<uint32_t> ret;
+    for (uint32_t i = 0; i < num_args; ++i) {
+      ret.push_back(i);
+    }
+    return ret;
+  })
+.set_num_outputs(0)
+.set_attr_parser(ParamParser<ResetArraysParam>)
+.set_attr<mxnet::FInferShape>("FInferShape", ResetArraysShape)
+.set_attr<nnvm::FInferType>("FInferType", ResetArraysType)
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+  [](const NodeAttrs& attrs) {
+    const uint32_t num_args = dmlc::get<ResetArraysParam>(attrs.parsed).num_arrays;
+    std::vector<std::string> ret;
+    for (uint32_t i = 0; i < num_args; ++i) {
+      ret.push_back(std::string("array_") + std::to_string(i));
+    }
+    return ret;
+  })
+.add_argument("data", "NDArray-or-Symbol[]", "Arrays")
+.add_arguments(ResetArraysParam::__FIELDS__());
+
+NNVM_REGISTER_OP(reset_arrays)
+.set_attr<FCompute>("FCompute<cpu>", ResetArrays<cpu>);
+
+// CPU specialization: plain memset; the stream argument is not used.
+template<>
+void ResetMemory<cpu>(void *pntr, size_t len, mshadow::Stream<cpu> *s) {
+  memset(pntr, 0, len);
+}
+
+}  // namespace op
+}  // namespace mxnet
diff --git a/src/operator/contrib/reset_arrays.cu b/src/operator/contrib/reset_arrays.cu
new file mode 100644
index 000000000000..f7a9d0034665
--- /dev/null
+++ b/src/operator/contrib/reset_arrays.cu
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *  Copyright (c) 2019 by Contributors
+ * \file reset_arrays.cu
+ * \brief setting all array element values to zeros
+ * \author Moises Hernandez-Fernandez, Andrei Ivanov
+ */
+#include "./reset_arrays-inl.h"
+
+namespace mxnet {
+namespace op {
+
+// GPU specialization: cudaMemsetAsync issued on the CUDA stream obtained from `s`.
+template<>
+void ResetMemory<gpu>(void *pntr, size_t len, mshadow::Stream<gpu> *s) {
+  CUDA_CALL(cudaMemsetAsync(pntr, 0, len, mshadow::Stream<gpu>::GetStream(s)));
+}
+
+NNVM_REGISTER_OP(reset_arrays)
+.set_attr<FCompute>("FCompute<gpu>", ResetArrays<gpu>);
+
+}  // namespace op
+}  // namespace mxnet
diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py
index 380ce762a9f7..c6da0f4d9d96 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -33,6 +33,7 @@
 import warnings
 import json
 import unittest
+import random
 
 @with_seed()
 def test_parameter():
@@ -1504,15 +1505,63 @@ def test_hybrid_multi_context():
 
 @with_seed()
 def test_zero_grad():
-    data = mx.nd.random.uniform(shape=(3,3))
-    net = nn.Embedding(3, 4, sparse_grad=True, prefix='test_zero_grad_')
-    net.initialize()
-    with mx.autograd.record():
-        l = net(data)
-        l.backward()
-    net.collect_params().zero_grad()
-    grad = net.collect_params()['test_zero_grad_weight'].grad()
-    assert_almost_equal(grad.asnumpy(), grad.asnumpy() * 0)
+    def _test_grad_reset(ctx, dtype='float32', sparse=False, embeddingType=None):
+        data = mx.nd.random.uniform(shape=(3,3), dtype=dtype, ctx=ctx)
+        if embeddingType is None:
+            embeddingType = dtype
+        net = nn.Embedding(3, 4, sparse_grad=sparse, prefix='test_zero_grad_', dtype=embeddingType)
+        net.initialize(ctx=ctx)
+        with mx.autograd.record():
+            l = net(data)
+            l.backward()
+        net.collect_params().zero_grad()
+        grad = net.collect_params()['test_zero_grad_weight'].grad()
+        assert_almost_equal(grad.asnumpy(), grad.asnumpy() * 0)
+
+    def _test_multi_reset(nArrays, dtype, ctx):
+        # Construct the list of non-zeros arrays with random shapes
+        arr = []
+        for _ in range(nArrays):
+            arrType = random.choice(dtype) if isinstance(dtype, list) else dtype
+            shape = ()
+            for _ in range(np.random.randint(1, 5)):
+                shape = shape + (np.random.randint(1, 10),)
+            arr.append(mx.nd.random.uniform(shape=shape, dtype=arrType, ctx=ctx))
+
+        # Reset all arrays
+        mx.nd.reset_arrays(*arr, num_arrays=len(arr))
+
+        # Check results
+        for a in arr:
+            grad = a.asnumpy()
+            assert_almost_equal(grad, grad * 0)
+
+
+    # Setting context for current test
+    ctx = mx.context.current_context()
+
+    # Launching _test_multi_reset 10 times with different types & randomly chosen nArrays
+    testedTypes = ['float16', 'float32', 'float64']
+    for _ in range(10):
+        for dtype in [testedTypes] + testedTypes:
+            _test_multi_reset(np.random.randint(1, 50), dtype, ctx)
+
+    # Saving value of environment variable, if it was defined
+    envVarKey = 'MXNET_STORAGE_FALLBACK_LOG_VERBOSE'
+    envVarValue = os.environ.get(envVarKey)
+    # Changing value of environment variable
+    os.environ[envVarKey] = '0'
+    try:
+        for dtype in testedTypes:
+            for embType in ['float32', 'float64']:
+                for sparse in [True, False]:
+                    _test_grad_reset(ctx, dtype=dtype, sparse=sparse, embeddingType=embType)
+    finally:
+        # Remove or restore the value of environment variable, even if an assertion above failed
+        if envVarValue is None:
+            del os.environ[envVarKey]
+        else:
+            os.environ[envVarKey] = envVarValue
 
 def check_hybrid_static_memory(**kwargs):
     x = mx.nd.random.uniform(shape=(2, 3, 32, 32))