diff --git a/python/mxnet/gluon/trainer.py b/python/mxnet/gluon/trainer.py
index 01f76d637a97..1ab86af2b93f 100644
--- a/python/mxnet/gluon/trainer.py
+++ b/python/mxnet/gluon/trainer.py
@@ -71,8 +71,11 @@ class Trainer(object):
     """
     def __init__(self, params, optimizer, optimizer_params=None, kvstore='device',
                  compression_params=None, update_on_kvstore=None):
+        param_list = []
         if isinstance(params, (dict, ParameterDict)):
-            params = list(params.values())
+            for key in sorted(list(params.keys())):
+                param_list.append(params[key])
+            params = param_list
         if not isinstance(params, (list, tuple)):
             raise ValueError(
                 "First argument must be a list or dict of Parameters, " \
diff --git a/src/executor/simple_partition_pass.h b/src/executor/simple_partition_pass.h
index 5b26a4523c13..ea1dcf39b8ba 100644
--- a/src/executor/simple_partition_pass.h
+++ b/src/executor/simple_partition_pass.h
@@ -102,8 +102,7 @@ class BidirectionalGraph {
   std::vector<std::unordered_set<Node*>> get_subsets(FCompatible is_compatible) {
     std::vector<std::unordered_set<Node*>> subgraphs;
     std::unordered_set<Node*> incomp_set;
-    std::unordered_set<Node*> all_set(nodes.size());
-    std::vector<PairSet> separation_sets;
+    std::vector<std::pair<bool, PairSet>> separation_sets;
     // Check each node for compatibility
     // and, if it is incompatible, mark nodes
     // on each side of it as not possible to be
@@ -111,48 +110,79 @@ class BidirectionalGraph {
     for (Node& node : nodes) {
       if (!is_compatible(node.nnvmptr)) {
         incomp_set.insert(&node);
-        std::unordered_set<Node*> in_graph;
-        std::unordered_set<Node*> out_graph;
-        std::vector<Node*> dummy_head;
-        dummy_head.emplace_back(&node);
-        DFS(dummy_head, false, [&out_graph, &is_compatible](Node* node) {
-          if (is_compatible(node->nnvmptr))
-            out_graph.insert(node);
-        });
-        DFS(dummy_head, true, [&in_graph, is_compatible](Node* node) {
-          if (is_compatible(node->nnvmptr))
-            in_graph.insert(node);
-        });
-        if (!(in_graph.empty() || out_graph.empty()))
-          separation_sets.push_back(std::make_pair(in_graph, out_graph));
       }
-      all_set.emplace(&node);
     }
-    IncompMap incomp_map;
-    std::unordered_set<Node*> comp_set;
-    comp_set.insert(all_set.begin(), all_set.end());
-    for (Node* n : incomp_set) {
-      comp_set.erase(n);
+    for (Node& node : nodes) {
+      if (incomp_set.count(&node) != 0) {
+        // Check if all your inputs are incompatible too.
+        // If so, then your separation set does not matter,
+        // because it will be covered by the sets of your inputs
+        bool inside_node = true;
+        for (Node* input : node.inputs) {
+          if (incomp_set.count(input) == 0) {
+            inside_node = false;
+          }
+        }
+        if (!inside_node) {
+          std::unordered_set<Node*> in_graph;
+          std::unordered_set<Node*> out_graph;
+          std::vector<Node*> dummy_head;
+          dummy_head.emplace_back(&node);
+          DFS(dummy_head, false, [&out_graph](Node* node) {
+            out_graph.insert(node);
+          });
+          DFS(dummy_head, true, [&in_graph](Node* node) {
+            in_graph.insert(node);
+          });
+          separation_sets.push_back(std::make_pair(true,
+                                                   std::make_pair(in_graph, out_graph)));
+        } else {
+          separation_sets.push_back(std::make_pair(false, PairSet()));
+        }
+      } else {
+        separation_sets.push_back(std::make_pair(false, PairSet()));
+      }
     }
+    IncompMap incomp_map;
     // For each node construct the map of nodes that cannot be in
     // the same subset
-    for (Node* n : comp_set) {
-      for (PairSet p : separation_sets) {
-        if (p.first.count(n)) {
-          incomp_map[n].insert(p.second.begin(), p.second.end());
-        } else if (p.second.count(n)) {
-          incomp_map[n].insert(p.first.begin(), p.first.end());
+    index_t num_nodes = nodes.size();
+    for (index_t i = 0; i < num_nodes; ++i) {
+      const auto n = &(nodes[i]);
+      if (incomp_set.count(n) == 0) {
+        for (index_t j = i + 1; j < num_nodes; ++j) {
+          const auto& sep_set_pair = separation_sets[j];
+          if (sep_set_pair.first && incomp_map[n].count(&nodes[j]) == 0) {
+            const auto& p = sep_set_pair.second;
+            if (p.first.count(n)) {
+              incomp_map[n].insert(p.second.begin(), p.second.end());
+            } else if (p.second.count(n)) {
+              incomp_map[n].insert(p.first.begin(), p.first.end());
+            }
+          }
+        }
+        for (index_t j = i - 1; j >= 0; --j) {
+          const auto& sep_set_pair = separation_sets[j];
+          if (sep_set_pair.first && incomp_map[n].count(&nodes[j]) == 0) {
+            const auto& p = sep_set_pair.second;
+            if (p.first.count(n)) {
+              incomp_map[n].insert(p.second.begin(), p.second.end());
+            } else if (p.second.count(n)) {
+              incomp_map[n].insert(p.first.begin(), p.first.end());
+            }
+          }
+        }
+        for (Node* incomp_n : incomp_set) {
+          incomp_map[n].erase(incomp_n);
        }
-      }
-      for (Node* incomp_n : incomp_set) {
-        incomp_map[n].erase(incomp_n);
      }
    }
     std::unordered_set<Node*> unused_set;
-    unused_set.reserve(comp_set.size());
-    for (auto& n : comp_set) {
-      unused_set.insert(n);
+    for (auto& n : nodes) {
+      if (incomp_set.count(&n) == 0) {
+        unused_set.insert(&n);
+      }
     }
     std::unordered_set<Node*> visited;
     std::deque<Node*> stack(outputs.begin(), outputs.end());
diff --git a/src/imperative/cached_op.cc b/src/imperative/cached_op.cc
index 24270f210888..8f4b8e32b26e 100644
--- a/src/imperative/cached_op.cc
+++ b/src/imperative/cached_op.cc
@@ -1032,17 +1032,19 @@ OpStatePtr CachedOp::Forward(
   CHECK_EQ(inputs.size(), num_inputs());
 
   Context default_ctx = inputs[0]->ctx();
-  auto state_ptr = GetCachedOpState(default_ctx);
-  auto& state = state_ptr.get_state<CachedOpState>();
+  {
+    auto state_ptr = GetCachedOpState(default_ctx);
+    auto& state = state_ptr.get_state<CachedOpState>();
 
-  const auto& idx = state.info.fwd_graph.indexed_graph();
-  for (size_t i = 0; i < inputs.size(); ++i) {
-    CHECK_EQ(inputs[i]->ctx(), default_ctx)
-        << "CachedOp requires all inputs to live on the same context. But "
-        << idx[idx.input_nodes()[0]].source->attrs.name
-        << " is on " << default_ctx << " while "
-        << idx[idx.input_nodes()[i]].source->attrs.name
-        << " is on " << inputs[i]->ctx();
+    const auto& idx = state.info.fwd_graph.indexed_graph();
+    for (size_t i = 0; i < inputs.size(); ++i) {
+      CHECK_EQ(inputs[i]->ctx(), default_ctx)
+          << "CachedOp requires all inputs to live on the same context. But "
+          << idx[idx.input_nodes()[0]].source->attrs.name
+          << " is on " << default_ctx << " while "
+          << idx[idx.input_nodes()[i]].source->attrs.name
+          << " is on " << inputs[i]->ctx();
+    }
   }
 
   int prev_bulk_size = Engine::Get()->set_bulk_size(config_.forward_bulk_size);
diff --git a/src/operator/contrib/multi_sum_sq-inl.h b/src/operator/contrib/multi_sum_sq-inl.h
index 876155215d1c..b8609c0f217f 100644
--- a/src/operator/contrib/multi_sum_sq-inl.h
+++ b/src/operator/contrib/multi_sum_sq-inl.h
@@ -21,7 +21,7 @@
  * Copyright (c) 2019 by Contributors
  * \file multi_l2_norm-inl.h
  * \brief vectorized L2 norm over multiple arrays operators
- * \author Clement Fuji Tsang, Andrei Ivanov
+ * \author Clement Fuji Tsang, Andrei Ivanov, Moises Hernandez
  */
@@ -32,6 +32,10 @@
 #include <vector>
 #include "../operator_common.h"
 
+namespace multi_sum_sq {
+enum MultiSumSqUpdateResource {kTempSpace};
+} // namespace multi_sum_sq
+
 namespace mxnet {
 namespace op {
@@ -80,7 +84,7 @@ inline bool MultiSumSqType(const NodeAttrs& attrs,
 
 template<typename xpu>
 void MultiSumSqRun(const std::vector<TBlob> &inputs, int nInputs,
-                   float *out_ptr, mshadow::Stream<xpu> *s);
+                   float *out_ptr, const OpContext &ctx);
 
 template<typename xpu>
 void MultiSumSq(const nnvm::NodeAttrs& attrs,
@@ -91,7 +95,7 @@ void MultiSumSq(const nnvm::NodeAttrs& attrs,
   auto s = ctx.get_stream<xpu>();
   const auto& p = dmlc::get<MultiSumSqParam>(attrs.parsed);
   float* out_ptr = outputs[0].FlatTo2D<xpu, float>(s).dptr_;
-  MultiSumSqRun<xpu>(inputs, p.num_arrays, out_ptr, s);
+  MultiSumSqRun<xpu>(inputs, p.num_arrays, out_ptr, ctx);
 }
 
 } // namespace op
diff --git a/src/operator/contrib/multi_sum_sq.cc b/src/operator/contrib/multi_sum_sq.cc
index cdb5423db23f..16c99d1c9699 100644
--- a/src/operator/contrib/multi_sum_sq.cc
+++ b/src/operator/contrib/multi_sum_sq.cc
@@ -21,7 +21,7 @@
  * Copyright (c) 2019 by Contributors
  * \file multi_sum_sq.cc
  * \brief vectorized sum or squared over multiple arrays operators
- * \author Clement Fuji Tsang, Andrei Ivanov
+ * \author Clement Fuji Tsang, Andrei Ivanov, Moises Hernandez
  */
 
 #include "./multi_sum_sq-inl.h"
@@ -52,20 +52,24 @@ NNVM_REGISTER_OP(multi_sum_sq)
     return ret;
   })
.set_attr<FCompute>("FCompute<cpu>", MultiSumSq<cpu>)
+.set_attr<FResourceRequest>("FResourceRequest",
+  [](const NodeAttrs& attrs) {
+    return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+  })
.add_argument("data", "NDArray-or-Symbol[]", "Arrays")
.add_arguments(MultiSumSqParam::__FIELDS__());
 
 template<typename DType>
-inline void CalcSumSq(const std::vector<TBlob> &inputs, int nInputs,
+inline void CalcSumSq(const std::vector<TBlob> &inputs, int n_inputs,
                       float *out_ptr, mshadow::Stream<cpu> *s) {
   int i;
   size_t j;
 #pragma omp parallel for private(i, j)
-  for (i = 0; i < nInputs; ++i) {  // array index in inputs
+  for (i = 0; i < n_inputs; ++i) {  // array index in inputs
     float sum = 0;
     const auto address = inputs[i].FlatTo2D<cpu, DType>(s).dptr_;
-    const auto jMax = inputs[i].shape_.Size();
-    for (j = 0; j < jMax; ++j)
+    const auto j_max = inputs[i].shape_.Size();
+    for (j = 0; j < j_max; ++j)
       sum += address[j] * address[j];
 
     out_ptr[i] = sum;
@@ -73,10 +77,10 @@ inline void CalcSumSq(const std::vector<TBlob> &inputs, int nInputs,
 }
 
 template<>
-void MultiSumSqRun<cpu>(const std::vector<TBlob> &inputs, int nInputs,
-                        float *out_ptr, mshadow::Stream<cpu> *s) {
+void MultiSumSqRun<cpu>(const std::vector<TBlob> &inputs, int n_inputs,
+                        float *out_ptr, const OpContext &ctx) {
   MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType,
-    CalcSumSq<DType>(inputs, nInputs, out_ptr, s);
+    CalcSumSq<DType>(inputs, n_inputs, out_ptr, ctx.get_stream<cpu>());
   )
 }
diff --git a/src/operator/contrib/multi_sum_sq.cu b/src/operator/contrib/multi_sum_sq.cu
index 6f6fe56bfd81..620c9ca8a073 100644
--- a/src/operator/contrib/multi_sum_sq.cu
+++ b/src/operator/contrib/multi_sum_sq.cu
@@ -21,7 +21,7 @@
  * Copyright (c) 2019 by Contributors
  * \file multi_sum_sq.cu
  * \brief vectorized sums of squares norm over multiple arrays operators
- * \author Clement Fuji Tsang, Andrei Ivanov
+ * \author Clement Fuji Tsang, Andrei Ivanov, Moises Hernandez
  */
 #include "./multi_sum_sq-inl.h"
 #include <cub/cub.cuh>
@@ -43,96 +43,121 @@ struct MultiSumSqKernelParam {
   int sizes[ARRAY_LIMIT];
   unsigned char block_to_tensor[BLOCK_LIMIT];
   int block_to_chunk[BLOCK_LIMIT];
+  int max_chunks_per_tensor = -1;
 };
 
 template<typename DType>
-__device__ __forceinline__ DType reduce_block_into_lanes(DType* x,
-                                                         DType val,
-                                                         int lanes = 1,
-                                                         bool share_result = false) {
-  int tid = threadIdx.x + threadIdx.y * blockDim.x;
-  int blockSize = blockDim.x * blockDim.y;  // blockSize is intended to be a multiple of 32.
-
-  if (blockSize >= 64) {
+__device__ __forceinline__ DType ReduceBlockIntoLanes(DType* x,
+                                                      DType val) {
+  int tid = threadIdx.x;
+  int block_size = blockDim.x;
+
+  if (block_size >= 64) {
     x[tid] = val;
     __syncthreads();
   }
 
 #pragma unroll
-  for (int i = (blockSize >> 1); i >= 64; i >>= 1) {
+  for (int i = (block_size >> 1); i >= 64; i >>= 1) {
     if (tid < i)
       x[tid] = x[tid] + x[tid+i];
     __syncthreads();
   }
 
   DType final;
-
   if (tid < 32) {
-    if (blockSize >= 64)
+    if (block_size >= 64)
       final = x[tid] + x[tid+32];
     else
       final = val;
-    // __SYNCWARP();
 
 #pragma unroll
-    for (int i = 16; i >= lanes; i >>= 1)
+    for (int i = 16; i >= 1; i >>= 1)
      final = final + __shfl_down_sync(0xffffffff, final, i);
   }
-
-  if (share_result) {
-    if (tid < lanes)
-      x[tid] = final;  // EpilogueOp
-    // Make sure the smem result is visible to all warps.
-    __syncthreads();
-  }
-
   return final;
 }
 
 template<typename DType>
 __global__ void MultiSumSqKernel(int chunk_size,
                                  MultiSumSqKernelParam<DType> param,
-                                 float* output) {
+                                 float* block_reductions,
+                                 int start_tensor_id) {
   const int tensor_loc = param.block_to_tensor[blockIdx.x];
   const int chunk_len = param.block_to_chunk[blockIdx.x] * chunk_size;
   const int n = param.sizes[tensor_loc] - chunk_len;
   const DType* x = param.addresses[tensor_loc] + chunk_len;
-  const auto iMax = n <= chunk_size? n : chunk_size;
+  const auto i_max = n <= chunk_size ? n : chunk_size;
   __shared__ float vals[512];
 
   // Non-divergent exit condition for __syncthreads, not necessary here
   float val = 0;
   for (int i_start = 0;
-       i_start < iMax;
+       i_start < i_max;
        i_start += blockDim.x * ILP) {
     int i = i_start + threadIdx.x;
-    // #pragma unroll
-    for (int ii = 0; ii < ILP && i < iMax; ++ii, i += blockDim.x) {
+#pragma unroll
+    for (int ii = 0; ii < ILP && i < i_max; ++ii, i += blockDim.x) {
      const auto incoming_val = static_cast<float>(x[i]);
      val += incoming_val * incoming_val;
    }
  }
+  const float final = ReduceBlockIntoLanes(vals, val);
+
+  if (threadIdx.x == 0) {
+    block_reductions[(start_tensor_id + tensor_loc) * param.max_chunks_per_tensor +
+                     param.block_to_chunk[blockIdx.x]] = final;
+  }
+}
+
+template<typename DType>
+__global__ void GlobalReductionKernel(MultiSumSqKernelParam<DType> param,
+                                      float* block_reductions,
+                                      float* output) {
+  __shared__ float vals[512];
+  float* reductions_this_tensor = block_reductions + blockIdx.x * param.max_chunks_per_tensor;
+  float val = 0;
+  for (int i = threadIdx.x; i < param.max_chunks_per_tensor; i += blockDim.x)
+    val += reductions_this_tensor[i];
+
+  float final = ReduceBlockIntoLanes(vals, val);
 
-  const float final = reduce_block_into_lanes(vals, val);
   if (threadIdx.x == 0)
-    atomicAdd(output + tensor_loc, final);
+    output[blockIdx.x] = final;
 }
 
 template<>
-void MultiSumSqRun<gpu>(const std::vector<TBlob> &inputs, int nInputs,
-                        float *out_ptr, mshadow::Stream<gpu> *s) {
+void MultiSumSqRun<gpu>(const std::vector<TBlob> &inputs, int n_inputs,
+                        float *out_ptr, const OpContext &ctx) {
   const int chunk_size = 32768;
   const int block_size = 512;
   using namespace mxnet_op;
+  auto s = ctx.get_stream<gpu>();
   auto stream = mshadow::Stream<gpu>::GetStream(s);
-  CUDA_CALL(cudaMemsetAsync(out_ptr, 0, nInputs * sizeof(float), stream));
 
   MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
     MultiSumSqKernelParam<DType> param;
+    // find max num of chunks in tensors
+    for (int t = 0; t < n_inputs; t++) {
+      int chunks_this_tensor = (inputs[t].shape_.Size() + chunk_size - 1) / chunk_size;
+      if (chunks_this_tensor > param.max_chunks_per_tensor)
+        param.max_chunks_per_tensor = chunks_this_tensor;
+    }
+    // temporary storage for the reduction of each block
+    size_t workspace_size = n_inputs * param.max_chunks_per_tensor * sizeof(float);
+    Tensor<gpu, 1, char> workspace =
+      ctx.requested[multi_sum_sq::kTempSpace].get_space_typed<gpu, 1, char>(
+        Shape1(workspace_size), s);
+    Tensor<gpu, 1, float> block_reductions(reinterpret_cast<float*>(&workspace[0]),
+                                           Shape1(n_inputs * param.max_chunks_per_tensor), s);
+    CUDA_CALL(cudaMemsetAsync(block_reductions.dptr_, 0,
+                              n_inputs * param.max_chunks_per_tensor * sizeof(float),
+                              stream));
+
     int loc_block_info = 0;  // position in param.block_to_tensor and param.block_to_chunck
     int loc_tensor_info = 0;  // position in param.sizes and param.addresses
-    int output_offset = 0;  // array index of the first block pointed on by param.addresses
-    for (int t = 0; t < nInputs; t++, loc_tensor_info++) {  // array index in inputs
+    int start_tensor_id = 0;
+    for (int t = 0; t < n_inputs; t++, loc_tensor_info++) {  // array index in inputs
      param.sizes[loc_tensor_info] = inputs[t].shape_.Size();
      param.addresses[loc_tensor_info] = inputs[t].FlatTo2D<gpu, DType>(s).dptr_;
      const int chunks_this_tensor = (inputs[t].shape_.Size() - 1) / chunk_size;
@@ -142,27 +167,30 @@ void MultiSumSqRun<gpu>(const std::vector<TBlob> &inputs, int nInputs,
        loc_block_info++;
 
        const bool last_curr_chunk = chunk == chunks_this_tensor;
-        const bool tensors_full = last_curr_chunk && loc_tensor_info == 109;
-        const bool blocks_full = (loc_block_info == 320);
-        const bool last_chunk = last_curr_chunk && t == nInputs - 1;
+        const bool tensors_full = last_curr_chunk && loc_tensor_info == (ARRAY_LIMIT-1);
+        const bool blocks_full = (loc_block_info == BLOCK_LIMIT);
+        const bool last_chunk = last_curr_chunk && t == n_inputs - 1;
        if (!(tensors_full || blocks_full || last_chunk))
          continue;
-
        MultiSumSqKernel<<<loc_block_info, block_size, 0, stream>>>
-          (chunk_size, param, out_ptr + output_offset);
+          (chunk_size, param, block_reductions.dptr_, start_tensor_id);
        MSHADOW_CUDA_POST_KERNEL_CHECK(MultiSumSqKernel);
+
        loc_block_info = 0;
        if (last_curr_chunk) {  // if you start from a new tensor
          loc_tensor_info = -1;
-          output_offset = t + 1;
+          start_tensor_id = t + 1;
        } else {  // if you start from the same tensor
          param.sizes[0] = param.sizes[loc_tensor_info];
          param.addresses[0] = param.addresses[loc_tensor_info];
          loc_tensor_info = 0;
-          output_offset = t;
+          start_tensor_id = t;
        }
      }
    }
+    // Global reduction
+    GlobalReductionKernel<<<n_inputs, block_size, 0, stream>>>
+      (param, block_reductions.dptr_, out_ptr);
  });
 }
diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py
index 405c4d946df7..b219011831bd 100644
--- a/tests/python/gpu/test_operator_gpu.py
+++ b/tests/python/gpu/test_operator_gpu.py
@@ -271,6 +271,36 @@ def test_fft():
 def _make_ndarrays(input_list, ctx=mx.gpu(0)):
     return [mx.nd.array(arr, dtype=arr.dtype, ctx=ctx) for arr in input_list]
 
+def check_multi_sum_sq(dtype, shapes, ctx, tol1, tol2):
+    values_arr = [np.random.rand(*shape).astype(dtype) * 10. for shape in shapes]
+    mx_vals = _make_ndarrays(values_arr, ctx=ctx)
+    sum_sq = mx.nd.multi_sum_sq(*mx_vals, num_arrays=len(shapes))
+    sum_sq2 = mx.nd.multi_sum_sq(*mx_vals, num_arrays=len(shapes))
+    # checks that operator is deterministic
+    assert np.array_equal(sum_sq.asnumpy(), sum_sq2.asnumpy())
+
+    ref_sum_sq = mx.nd.array([(v.astype('float32') ** 2).sum() for v in values_arr],
+                             dtype='float32', ctx=ctx)
+    assert_almost_equal(ref_sum_sq.asnumpy(), sum_sq.asnumpy(), atol=tol1, rtol=tol1)
+
+@with_seed()
+def test_multi_sum_sq():
+    min_nparam = 100
+    max_nparam = 120
+    min_dim = 50000
+    max_dim = 100000
+    max_ndim = 1
+
+    dtypes = ['float16', 'float32', 'float64']
+    for ctx in [mx.gpu(0)]:
+        for dtype in dtypes:
+            nparam = np.random.randint(min_nparam + 1, max_nparam + 1)
+            shapes = [np.random.randint(min_dim, max_dim + 1, size=max_ndim) for i in range(nparam)]
+            low_tol = ctx == mx.cpu(0) and ('float16' in [dtype])
+            tol1 = 1e-3 if low_tol else 1e-5
+            tol2 = 1e-6 if low_tol else 1e-7
+            check_multi_sum_sq(dtype, shapes, ctx, tol1, tol2)
+
 def check_fast_lars(w_dtype, g_dtype, shapes, ctx, tol1, tol2):
     weights_arr = [np.random.rand(*shape).astype(w_dtype) * 10. for shape in shapes]
     grads_arr = [np.random.rand(*shape).astype(g_dtype) for shape in shapes]
diff --git a/tests/python/unittest/test_gluon_trainer.py b/tests/python/unittest/test_gluon_trainer.py
index 2d5874a8b97b..9f02733d0a25 100644
--- a/tests/python/unittest/test_gluon_trainer.py
+++ b/tests/python/unittest/test_gluon_trainer.py
@@ -291,3 +291,19 @@ def test_trainer_lr_sched():
         assert trainer.learning_rate == lr, (lr, trainer.learning_rate, i)
         lr *= factor
     mx.nd.waitall()
+
+@with_seed()
+def test_gluon_trainer_param_order():
+    net = mx.gluon.nn.Sequential()
+    # layers may be added in a random order for all workers
+    layers = {'ones_': 1, 'zeros_': 0}
+    for name, init in layers.items():
+        net.add(mx.gluon.nn.Dense(10, in_units=10, weight_initializer=mx.init.Constant(init),
+                                  use_bias=False, prefix=name))
+    params = net.collect_params()
+    net.initialize()
+    trainer = gluon.Trainer(params, 'sgd')
+    for name, init in layers.items():
+        expected_idx = 0 if name == 'ones_' else 1
+        expected_name = name + 'weight'
+        assert trainer._params[expected_idx].name == expected_name