From ff759f20ef223ab922a2f9aa3cbfaf127c802098 Mon Sep 17 00:00:00 2001
From: Dick Carter
Date: Tue, 25 Sep 2018 17:15:38 -0700
Subject: [PATCH 01/10] Dual stream conv backward(). Enable with
 MXNET_GPU_WORKER_NSTREAMS=2.

---
 include/mxnet/base.h                          | 28 ++++++
 include/mxnet/op_attr_types.h                 |  4 +
 src/engine/naive_engine.cc                    |  4 +-
 src/engine/stream_manager.h                   |  9 +-
 src/engine/threaded_engine_perdevice.cc       |  8 +-
 src/operator/nn/cudnn/cudnn_convolution-inl.h | 93 +++++++++++++++----
 tests/python/gpu/test_operator_gpu.py         | 44 +++++++++
 7 files changed, 168 insertions(+), 22 deletions(-)

diff --git a/include/mxnet/base.h b/include/mxnet/base.h
index f88b22784a1b..c54b3cc97bff 100644
--- a/include/mxnet/base.h
+++ b/include/mxnet/base.h
@@ -193,6 +193,11 @@ struct Context {
    * \return The number of GPUs that are available.
    */
   inline static int32_t GetGPUCount();
+  /*!
+   * \brief Get the number of streams that a GPU worker has available for operations.
+   * \return The number of streams that are available.
+   */
+  inline static int32_t GetGPUStreamsPerWorker();
   /*!
    * \brief get the free and total available memory on a GPU
    * \param dev the GPU number to query
@@ -232,6 +237,10 @@ struct RunContext {
    * \brief the stream of the device, can be NULL or Stream<gpu>* in GPU mode
    */
  void *stream;
+  /*!
+   * \brief the auxiliary stream of the device, can be NULL or Stream<gpu>* in GPU mode
+   */
+  void *aux_stream;
  /*!
   * \brief get mshadow stream from Context
   * \return the mshadow stream
   */
  inline mshadow::Stream<xpu>* get_stream() const {
    return static_cast<mshadow::Stream<xpu>*>(stream);
  }
+  /*!
+   * \brief get the auxiliary (i.e. 2nd) mshadow stream from Context
+   * The user must sync work enqueued to this stream with the primary stream using events.
+   * \return the mshadow stream
+   * \tparam xpu the device type of the stream
+   */
+  template<typename xpu>
+  inline mshadow::Stream<xpu>* get_aux_stream() const {
+    return static_cast<mshadow::Stream<xpu>*>(aux_stream);
+  }
  /*! \brief get the base Context from RunContext */
  inline const Context& get_ctx() const {
    return ctx;
  }
@@ -305,6 +324,15 @@ inline int32_t Context::GetGPUCount() {
 #endif
 }

+inline int32_t Context::GetGPUStreamsPerWorker() {
+  // The default number of streams available if the user has not set MXNET_GPU_WORKER_NSTREAMS.
+  const int32_t default_num_streams = 1;
+  // The get_aux_stream() interface can supply one additional stream beyond the standard one.
+  static int32_t num_streams =
+    dmlc::GetEnv("MXNET_GPU_WORKER_NSTREAMS", default_num_streams) >= 2 ? 2 : 1;
+  return num_streams;
+}
+
 inline void Context::GetGPUMemoryInformation(int dev, uint64_t *free_mem,
                                              uint64_t *total_mem) {
 #if MXNET_USE_CUDA
diff --git a/include/mxnet/op_attr_types.h b/include/mxnet/op_attr_types.h
index 41be554953fd..39c9783977e8 100644
--- a/include/mxnet/op_attr_types.h
+++ b/include/mxnet/op_attr_types.h
@@ -83,6 +83,10 @@ struct OpContext {
   inline mshadow::Stream<xpu>* get_stream() const {
     return run_ctx.get_stream<xpu>();
   }
+  template<typename xpu>
+  inline mshadow::Stream<xpu>* get_aux_stream() const {
+    return run_ctx.get_aux_stream<xpu>();
+  }
 };
 /*! \brief the execution type of the operator */
diff --git a/src/engine/naive_engine.cc b/src/engine/naive_engine.cc
index 05b72d2a6fde..bef71a8314b3 100644
--- a/src/engine/naive_engine.cc
+++ b/src/engine/naive_engine.cc
@@ -173,12 +173,12 @@ class NaiveEngine final : public Engine {
       if (streams_[dev_id] == nullptr) {
         streams_[dev_id] = mshadow::NewStream<gpu>(true, MXNET_USE_CUDNN != 0, dev_id);
       }
-      exec_fun(RunContext{exec_ctx, streams_[dev_id]}, callback);
+      exec_fun(RunContext{exec_ctx, streams_[dev_id], streams_[dev_id]}, callback);
 #else
       LOG(FATAL) << "GPU is not enabled";
 #endif
     } else {
-      exec_fun(RunContext{exec_ctx, &cpu_stream_}, callback);
+      exec_fun(RunContext{exec_ctx, &cpu_stream_, nullptr}, callback);
     }
     CHECK(this->req_completed_)
         << "NaiveEngine only support synchronize Push so far";
diff --git a/src/engine/stream_manager.h b/src/engine/stream_manager.h
index 516e04bf5e82..88fd28849d4f 100644
--- a/src/engine/stream_manager.h
+++ b/src/engine/stream_manager.h
@@ -55,6 +55,8 @@ class StreamManager {
 #if MXNET_USE_CUDA
   std::array<std::array<mshadow::Stream<gpu>*, kStreams>, kNumGpus>
       gpu_streams_;
+  std::array<std::array<mshadow::Stream<gpu>*, kStreams>, kNumGpus>
+      gpu_aux_streams_;
   std::array<mshadow::Stream<gpu>*, kNumGpus> gpu_io_streams_;
   std::array<int, kNumGpus> gpu_cnt_;
 #endif  // MXNET_USE_CUDA
@@ -80,12 +82,17 @@ RunContext StreamManager<kNumGpus, kStreams>::GetRunContext(
         for (auto&& i : gpu_streams_.at(ctx.dev_id)) {
           i = mshadow::NewStream<gpu>(true, MXNET_USE_CUDNN != 0, ctx.dev_id);
         }
+        for (auto&& i : gpu_aux_streams_.at(ctx.dev_id)) {
+          i = Context::GetGPUStreamsPerWorker() >= 2 ?
+              mshadow::NewStream<gpu>(true, MXNET_USE_CUDNN != 0, ctx.dev_id) : nullptr;
+        }
         counter = 0;
       }
       use_counter = counter;
       counter = (counter + 1) % kStreams;
     }
-    ret = RunContext{ctx, gpu_streams_.at(ctx.dev_id).at(use_counter)};
+    ret = RunContext{ctx, gpu_streams_.at(ctx.dev_id).at(use_counter),
+                     gpu_aux_streams_.at(ctx.dev_id).at(use_counter)};
     break;
 #else
     LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR;
diff --git a/src/engine/threaded_engine_perdevice.cc b/src/engine/threaded_engine_perdevice.cc
index b6537dabb638..149430fc6f2a 100644
--- a/src/engine/threaded_engine_perdevice.cc
+++ b/src/engine/threaded_engine_perdevice.cc
@@ -244,7 +244,8 @@ class ThreadedEnginePerDevice : public ThreadedEngine {
     this->is_worker_ = true;
 #if MXNET_USE_CUDA
     CHECK(block != nullptr);
-    mshadow::Stream<gpu> *stream;
+    mshadow::Stream<gpu> *stream = nullptr;
+    mshadow::Stream<gpu> *aux_stream = nullptr;
     do {
       ThreadPool::SetReadyOnDestroy setReady(ready_event);
       // allocate stream
       if (ctx.dev_mask() == Context::kCPUPinned) {
         stream = mshadow::NewStream<gpu>(false, false, ctx.dev_id);
       } else {
         stream = mshadow::NewStream<gpu>(true, MXNET_USE_CUDNN != 0, ctx.dev_id);
+        if (Context::GetGPUStreamsPerWorker() >= 2) {
+          aux_stream = mshadow::NewStream<gpu>(true, MXNET_USE_CUDNN != 0, ctx.dev_id);
+        }
       }
     } while (false);
     // execute task
     OprBlock* opr_block;
-    RunContext run_ctx{ctx, stream};
+    RunContext run_ctx{ctx, stream, aux_stream};
     auto* task_queue = &(block->task_queue);
     // Don't eat up omp threads for GPU jobs.  They're probably best used elsewhere,
diff --git a/src/operator/nn/cudnn/cudnn_convolution-inl.h b/src/operator/nn/cudnn/cudnn_convolution-inl.h
index 3bd6c5a3826b..2e31d7d0404d 100644
--- a/src/operator/nn/cudnn/cudnn_convolution-inl.h
+++ b/src/operator/nn/cudnn/cudnn_convolution-inl.h
@@ -53,6 +53,11 @@ class CuDNNConvolutionOp {
     CUDNN_CALL(cudnnCreateConvolutionDescriptor(&forward_conv_desc_));
     CUDNN_CALL(cudnnCreateConvolutionDescriptor(&back_conv_desc_));
     CUDNN_CALL(cudnnCreateConvolutionDescriptor(&back_conv_desc_w_));
+    parallelize_backward_kernels_ = Context::GetGPUStreamsPerWorker() >= 2;
+    if (parallelize_backward_kernels_) {
+      CUDA_CALL(cudaEventCreateWithFlags(&dgrad_can_start_, cudaEventDisableTiming));
+      CUDA_CALL(cudaEventCreateWithFlags(&dgrad_completion_, cudaEventDisableTiming));
+    }
   }

   void Init(const ConvolutionParam& param,
@@ -110,6 +115,7 @@ class CuDNNConvolutionOp {
     // future cuDNN releases.
     SelectAlgo(rctx, in_shape, out_shape,
                cudnn_forward_compute_type, cudnn_backward_compute_type);
+    GetTempSize(rctx);
   }

   ~CuDNNConvolutionOp() {
@@ -120,6 +126,10 @@ class CuDNNConvolutionOp {
     CUDNN_CALL(cudnnDestroyConvolutionDescriptor(forward_conv_desc_));
     CUDNN_CALL(cudnnDestroyConvolutionDescriptor(back_conv_desc_));
     CUDNN_CALL(cudnnDestroyConvolutionDescriptor(back_conv_desc_w_));
+    if (parallelize_backward_kernels_) {
+      CUDA_CALL(cudaEventDestroy(dgrad_can_start_));
+      CUDA_CALL(cudaEventDestroy(dgrad_completion_));
+    }
   }

   void Forward(const OpContext &ctx,
@@ -131,7 +141,6 @@ class CuDNNConvolutionOp {
     CHECK_EQ(in_data.size(), expected);
     CHECK_EQ(out_data.size(), 1U);
     Stream<gpu> *s = ctx.get_stream<gpu>();
-    GetTempSize(ctx);
     Tensor<gpu, 1, DType> workspace = AllocateTempWorkspace(ctx, forward_workspace_byte_);
     size_t workspace_size = TensorSizeBytes(workspace);

@@ -224,6 +233,14 @@ class CuDNNConvolutionOp {
     CHECK_EQ(in_data.size(), expected);
     CHECK_EQ(in_grad.size(), expected);
     Stream<gpu> *s = ctx.get_stream<gpu>();
+    Stream<gpu> *s_dgrad = parallelize_backward_kernels_ ? ctx.get_aux_stream<gpu>() : s;
+
+    // Make sure the dgrad kernel in the aux stream doesn't start before it would have
+    // had it been launched into the operator's primary stream.
+    if (parallelize_backward_kernels_ && req[conv::kData] != kNullOp) {
+      CUDA_CALL(cudaEventRecord(dgrad_can_start_, s->stream_));
+      CUDA_CALL(cudaStreamWaitEvent(s_dgrad->stream_, dgrad_can_start_, 0));
+    }

     // I/O's should have 2 more dims than the kernel dim
     DType *grad_ptr = GetNdPtr(out_grad[conv::kOut], param_.kernel.ndim() + 2, s);
     DType *data_ptr = GetNdPtr(in_data[conv::kData], param_.kernel.ndim() + 2, s);
     DType *gdata_ptr = GetNdPtr(in_grad[conv::kData], param_.kernel.ndim() + 2, s);

-    GetTempSize(ctx);
-    Tensor<gpu, 1, DType> workspace = AllocateTempWorkspace(ctx, backward_workspace_byte_);
+    size_t backward_workspace_byte =
+      parallelize_backward_kernels_ ? back_workspace_byte_dgrad_ + back_workspace_byte_wgrad_
+                                    : std::max(back_workspace_byte_dgrad_,
+                                               back_workspace_byte_wgrad_);
+    Tensor<gpu, 1, DType> workspace = AllocateTempWorkspace(ctx, backward_workspace_byte);
     size_t workspace_size = TensorSizeBytes(workspace);
+    DType *workspace_dptr_wgrad = workspace.dptr_;
+    DType *workspace_dptr_dgrad = workspace.dptr_;
+    if (parallelize_backward_kernels_) {
+      CHECK_LE(back_workspace_byte_dgrad_ + back_workspace_byte_wgrad_, workspace_size);
+      // Large allocations at some point will be given their own page.  Pass this alignment on to
+      // the larger of the two separate dgrad/wgrad workspaces.  This probably doesn't matter, but
+      // corresponds more closely to the workspace alignments used during cudnnFind.
+      if (back_workspace_byte_dgrad_ > back_workspace_byte_wgrad_)
+        workspace_dptr_wgrad = workspace.dptr_ + back_workspace_byte_dgrad_ / sizeof(DType);
+      else
+        workspace_dptr_dgrad = workspace.dptr_ + back_workspace_byte_wgrad_ / sizeof(DType);
+    } else {
+      CHECK_LE(back_workspace_byte_dgrad_, workspace_size);
+      CHECK_LE(back_workspace_byte_wgrad_, workspace_size);
+    }
 #if CUDNN_MAJOR >= 7
     typename DataType<DType>::ScaleType alpha = 1.0f;
     typename DataType<DType>::ScaleType beta = 0.0f;
@@ -259,14 +294,14 @@ class CuDNNConvolutionOp {
                grad_ptr,
                back_conv_desc_w_,
                back_algo_w_.AlgoNumber(),
-               workspace.dptr_,
-               workspace_size,
+               workspace_dptr_wgrad,
+               back_workspace_byte_wgrad_,
                req[conv::kWeight] == kAddTo? &beta_add : &beta,
                filter_desc_,
                gwmat_ptr));
     }
     if (req[conv::kData] != kNullOp) {
-      CUDNN_CALL(cudnnConvolutionBackwardData(s->dnn_handle_,
+      CUDNN_CALL(cudnnConvolutionBackwardData(s_dgrad->dnn_handle_,
                &alpha,
                filter_desc_,
                wmat_ptr,
                out_desc_,
                grad_ptr,
                back_conv_desc_,
                back_algo_.AlgoNumber(),
-               workspace.dptr_,
-               workspace_size,
+               workspace_dptr_dgrad,
+               back_workspace_byte_dgrad_,
                req[conv::kData] == kAddTo? &beta_add : &beta,
                in_desc_,
                gdata_ptr));
+      if (parallelize_backward_kernels_) {
+        CUDA_CALL(cudaEventRecord(dgrad_completion_, s_dgrad->stream_));
+        CUDA_CALL(cudaStreamWaitEvent(s->stream_, dgrad_completion_, 0));
+      }
     }
 #else
     for (uint32_t g = 0; g < param_.num_group; ++g) {
@@ -912,24 +951,30 @@ class CuDNNConvolutionOp {
   }

-  void GetTempSize(const OpContext& ctx) {
-    mshadow::Stream<gpu> *s = ctx.get_stream<gpu>();
-    size_t back_size = 0, back_size_w = 0;
+  void GetTempSize(const RunContext& rctx) {
+    mshadow::Stream<gpu> *s = rctx.get_stream<gpu>();
     CUDNN_CALL(cudnnGetConvolutionBackwardDataWorkspaceSize(s->dnn_handle_,
                filter_desc_,
                out_desc_,
                back_conv_desc_,
                in_desc_,
                back_algo_.AlgoNumber(),
-               &back_size));
+               &back_workspace_byte_dgrad_));
     CUDNN_CALL(cudnnGetConvolutionBackwardFilterWorkspaceSize(s->dnn_handle_,
                in_desc_,
                out_desc_,
                back_conv_desc_w_,
                filter_desc_,
                back_algo_w_.AlgoNumber(),
-               &back_size_w));
-    backward_workspace_byte_ = std::max(back_size, back_size_w);
+               &back_workspace_byte_wgrad_));
+    // cudaMalloc returns addresses that are aligned for large accesses (e.g. to 512 bytes).
+    // Since we only make one allocation and divide it into two parts when we parallelize
+    // the dgrad and wgrad kernels, we round the sizes up to this alignment size so the
+    // dptrs respect this alignment, even if the separate areas are stacked.
+    const size_t dptr_alignment = 512;
+    back_workspace_byte_dgrad_ = RoundToMultiple(back_workspace_byte_dgrad_, dptr_alignment);
+    back_workspace_byte_wgrad_ = RoundToMultiple(back_workspace_byte_wgrad_, dptr_alignment);
+
     CUDNN_CALL(cudnnGetConvolutionForwardWorkspaceSize(s->dnn_handle_,
                in_desc_,
                filter_desc_,
@@ -983,11 +1028,17 @@ class CuDNNConvolutionOp {
     CastTShapeToIntPtr(param_.pad, &param_pad_);
   }

+  // Round a value 'x' up to the next multiple of 'multiple'
+  size_t RoundToMultiple(size_t x, size_t multiple) {
+    size_t retVal = ((x + multiple - 1) / multiple) * multiple;
+    return retVal;
+  }
+
   // Allocates a 1D Tensor of words with size in bytes >= `size_bytes`.
   // Always allocates at least one word.
   mshadow::Tensor<gpu, 1, DType> AllocateTempWorkspace(const OpContext &ctx, size_t size_bytes) {
     mshadow::Stream<gpu> *s = ctx.get_stream<gpu>();
-    size_t size_words = size_bytes / sizeof(DType) + 1;
+    size_t size_words = std::max(1UL, RoundToMultiple(size_bytes, sizeof(DType)) / sizeof(DType));
     return ctx.requested[conv::kTempSpace].get_space_typed<gpu, 1, DType>(
         mshadow::Shape1(size_words), s);
   }
@@ -1035,8 +1086,10 @@ class CuDNNConvolutionOp {

   // Temp workspace size in bytes needed for Forward() operation.
   size_t forward_workspace_byte_;
-  // Temp workspace size in bytes needed for Backward() operation.
-  size_t backward_workspace_byte_;
+  // Temp workspace size in bytes needed for Backward() dgrad (data gradient) operation.
+  size_t back_workspace_byte_dgrad_;
+  // Temp workspace size in bytes needed for Backward() wgrad (weight gradient) operation.
+  size_t back_workspace_byte_wgrad_;
   size_t data_offset_;
   size_t out_offset_;
   size_t weight_offset_;
@@ -1052,6 +1105,12 @@ class CuDNNConvolutionOp {
   cudnnConvolutionDescriptor_t back_conv_desc_;
   // Convolution descriptor for back-prop operations to the weights
   cudnnConvolutionDescriptor_t back_conv_desc_w_;
+  // Should dgrad and wgrad be launched into separate streams
+  bool parallelize_backward_kernels_;
+  // Event to signal dgrad kernel aux stream completion back to the main stream of this operator.
+  cudaEvent_t dgrad_completion_;
+  // Event from the main stream of this operator that the dgrad kernel can begin in the aux stream.
+  cudaEvent_t dgrad_can_start_;
   // Algorithm for the forward inference operation
   CuDNNAlgo<cudnnConvolutionFwdAlgo_t> forward_algo_;
   // Algorithm for the back-prop operation to the data
diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py
index 7a7c6f69dd77..2d0be2fc6696 100644
--- a/tests/python/gpu/test_operator_gpu.py
+++ b/tests/python/gpu/test_operator_gpu.py
@@ -520,6 +520,50 @@ def test_convolution_options():
     sym_no_cudnn = mx.sym.Convolution(num_filter=3, kernel=(1,1,1), pad=(0,0,0), cudnn_off=True, name='conv')
     check_consistency_NxM([sym, sym_no_cudnn], ctx_list)

+
+# Helper function to run tests in a subprocess to avoid save/restore of os.environ.
+# Also avoids issues of cached environment variable lookups in the backend.
+def _test_in_separate_process(func, *args):
+    try:
+        mpctx = mp.get_context('spawn')
+    except:
+        print('SKIP: python%s.%s lacks the required process fork-exec support ... ' %
+              sys.version_info[0:2], file=sys.stderr, end='')
+    else:
+        seed = np.random.randint(0,1024*1024*1024)
+        # Prepend seed as first arg
+        p = mpctx.Process(target=func, args=(seed,)+args)
+        p.start()
+        p.join()
+        assert p.exitcode == 0, "Non-zero exit code %d from %s()." % (p.exitcode, func.__name__)
+
+def _conv_with_num_streams(seed, num_streams):
+    os.environ['MXNET_GPU_WORKER_NSTREAMS'] = str(num_streams)
+    with random_seed(seed):
+        num_trials = 10
+        for _ in range(num_trials):
+            size = np.random.randint(32, 512)
+            print('size = {}'.format(size))
+            # The cudnn conv operator runs dgrad and wgrad in separate streams if enabled, with possible
+            # kernel overlap.  The non-cudnn conv op doesn't do this so is used as the 'golden copy'.
+            ctx = {'ctx': mx.gpu(0), 'conv_data': (2, 2, size, size),
+                   'type_dict': {'conv_data': np.float32}}
+            ctx = {'ctx': mx.gpu(0), 'conv_data': (2, 2, size, size),
+                   'type_dict': {'conv_data': np.float32}}
+            # Adding 'flip' here isolates the model from the input node (which can't use inplace store)
+            flipped = mx.sym.flip(axis=0, name='conv')
+            sym = mx.sym.Convolution(data=flipped, num_filter=3, kernel=(3,3), pad=(1,1), name='conv')
+            flipped_no_cudnn = mx.sym.flip(axis=0, name='conv')
+            sym_no_cudnn = mx.sym.Convolution(data = flipped_no_cudnn, num_filter=3, kernel=(3,3), pad=(1,1),
+                                              cudnn_off=True, name='conv')
+            check_consistency([sym, sym_no_cudnn], [ctx, ctx])
+
+@with_seed()
+def test_convolution_multiple_streams():
+    _test_in_separate_process(_conv_with_num_streams, 1)
+    _test_in_separate_process(_conv_with_num_streams, 2)
+
+
 # This test is designed to expose an issue with cudnn v7.1.4 algo find() when invoked with large c.
 # Algos returned by find() can fail to run with grad_req='add' (wgrad kernel beta parameter == 1.0f).
 @with_seed()

From d52774db673db836463cd3e4f790b79d6f4de90e Mon Sep 17 00:00:00 2001
From: Dick Carter
Date: Mon, 28 Jan 2019 16:35:41 -0800
Subject: [PATCH 02/10] Fix for MSVC compiler.

---
 src/operator/nn/cudnn/cudnn_convolution-inl.h | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/operator/nn/cudnn/cudnn_convolution-inl.h b/src/operator/nn/cudnn/cudnn_convolution-inl.h
index 2e31d7d0404d..f06840ce2748 100644
--- a/src/operator/nn/cudnn/cudnn_convolution-inl.h
+++ b/src/operator/nn/cudnn/cudnn_convolution-inl.h
@@ -1038,7 +1038,7 @@ class CuDNNConvolutionOp {
   // Always allocates at least one word.
   mshadow::Tensor<gpu, 1, DType> AllocateTempWorkspace(const OpContext &ctx, size_t size_bytes) {
     mshadow::Stream<gpu> *s = ctx.get_stream<gpu>();
-    size_t size_words = std::max(1UL, RoundToMultiple(size_bytes, sizeof(DType)) / sizeof(DType));
+    size_t size_words = std::max<size_t>(1, RoundToMultiple(size_bytes, sizeof(DType)) / sizeof(DType));
     return ctx.requested[conv::kTempSpace].get_space_typed<gpu, 1, DType>(
         mshadow::Shape1(size_words), s);
   }

From 1cf5c67370b32d8e41b9669f03071dd8560383f6 Mon Sep 17 00:00:00 2001
From: Dick Carter
Date: Mon, 28 Jan 2019 16:54:56 -0800
Subject: [PATCH 03/10] Fix cpplint.

---
 src/operator/nn/cudnn/cudnn_convolution-inl.h | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/operator/nn/cudnn/cudnn_convolution-inl.h b/src/operator/nn/cudnn/cudnn_convolution-inl.h
index f06840ce2748..533e47f2458a 100644
--- a/src/operator/nn/cudnn/cudnn_convolution-inl.h
+++ b/src/operator/nn/cudnn/cudnn_convolution-inl.h
@@ -1038,7 +1038,8 @@ class CuDNNConvolutionOp {
   // Always allocates at least one word.
   mshadow::Tensor<gpu, 1, DType> AllocateTempWorkspace(const OpContext &ctx, size_t size_bytes) {
     mshadow::Stream<gpu> *s = ctx.get_stream<gpu>();
-    size_t size_words = std::max<size_t>(1, RoundToMultiple(size_bytes, sizeof(DType)) / sizeof(DType));
+    size_t size_words =
+      std::max<size_t>(1, RoundToMultiple(size_bytes, sizeof(DType)) / sizeof(DType));
     return ctx.requested[conv::kTempSpace].get_space_typed<gpu, 1, DType>(
         mshadow::Shape1(size_words), s);
   }

From e5b57a8e0c44dda3e276e5818c7adafe9a36aadf Mon Sep 17 00:00:00 2001
From: Dick Carter
Date: Mon, 28 Jan 2019 19:10:40 -0800
Subject: [PATCH 04/10] Add MXNET_GPU_WORKER_NSTREAMS env var documentation.
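
The documented clamping can be illustrated with a standalone C++ sketch (a
hypothetical program, not part of this patch; MXNet itself reads the variable
once via dmlc::GetEnv and caches the result in Context::GetGPUStreamsPerWorker()):

  // clamp_sketch.cc: how a setting of MXNET_GPU_WORKER_NSTREAMS maps to the
  // per-worker stream count: unset or < 2 gives 1 stream, anything >= 2 gives 2.
  #include <cstdlib>
  #include <cstdio>

  static int StreamsPerWorker() {
    const char *env = std::getenv("MXNET_GPU_WORKER_NSTREAMS");
    const int requested = (env != nullptr) ? std::atoi(env) : 1;  // default is 1
    return (requested >= 2) ? 2 : 1;  // only 1 or 2 streams are supported
  }

  int main() {
    std::printf("streams per GPU worker: %d\n", StreamsPerWorker());
    return 0;
  }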
---
 docs/faq/env_var.md | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/docs/faq/env_var.md b/docs/faq/env_var.md
index 99ebae21d61f..9c6f2d3f4c0c 100644
--- a/docs/faq/env_var.md
+++ b/docs/faq/env_var.md
@@ -174,6 +174,12 @@ When USE_PROFILER is enabled in Makefile or CMake, the following environments ca

 ## Other Environment Variables

+* MXNET_GPU_WORKER_NSTREAMS
+  - Values: 1, or 2 ```(default=1)```
+  - Determines the number of GPU streams available to operators for their functions.
+  - Setting this to 2 may yield a modest performance increase, since ops like the cuDNN convolution op can then calculate their data- and weight-gradients in parallel.
+  - Setting this to 2 may also increase a model's demand for GPU global memory.
+
 * MXNET_CUDNN_AUTOTUNE_DEFAULT
   - Values: 0, 1, or 2 ```(default=1)```
   - The default value of cudnn auto tuning for convolution layers.

From d16e85fc13f1254529cda875e3150686be79fedf Mon Sep 17 00:00:00 2001
From: Dick Carter
Date: Tue, 29 Jan 2019 19:33:21 -0800
Subject: [PATCH 05/10] Improve test function and commenting.

---
 tests/python/gpu/test_operator_gpu.py | 17 ++++++++++-------
 1 file changed, 10 insertions(+), 7 deletions(-)

diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py
index 2d0be2fc6696..237df55652d3 100644
--- a/tests/python/gpu/test_operator_gpu.py
+++ b/tests/python/gpu/test_operator_gpu.py
@@ -540,23 +540,26 @@ def _test_in_separate_process(func, *args):
 def _conv_with_num_streams(seed, num_streams):
     os.environ['MXNET_GPU_WORKER_NSTREAMS'] = str(num_streams)
     with random_seed(seed):
-        num_trials = 10
+        # Try to expose timing-dependent improper workspace sharing by parallel dgrad and wgrad
+        num_trials = 20
         for _ in range(num_trials):
-            size = np.random.randint(32, 512)
-            print('size = {}'.format(size))
+            size = np.random.randint(32, 128)
             # The cudnn conv operator runs dgrad and wgrad in separate streams if enabled, with possible
             # kernel overlap.  The non-cudnn conv op doesn't do this so is used as the 'golden copy'.
-            ctx = {'ctx': mx.gpu(0), 'conv_data': (2, 2, size, size),
-                   'type_dict': {'conv_data': np.float32}}
             ctx = {'ctx': mx.gpu(0), 'conv_data': (2, 2, size, size),
                    'type_dict': {'conv_data': np.float32}}
             # Adding 'flip' here isolates the model from the input node (which can't use inplace store)
             flipped = mx.sym.flip(axis=0, name='conv')
             sym = mx.sym.Convolution(data=flipped, num_filter=3, kernel=(3,3), pad=(1,1), name='conv')
             flipped_no_cudnn = mx.sym.flip(axis=0, name='conv')
-            sym_no_cudnn = mx.sym.Convolution(data = flipped_no_cudnn, num_filter=3, kernel=(3,3), pad=(1,1),
+            sym_no_cudnn = mx.sym.Convolution(data=flipped_no_cudnn, num_filter=3, kernel=(3,3), pad=(1,1),
                                               cudnn_off=True, name='conv')
-            check_consistency([sym, sym_no_cudnn], [ctx, ctx])
+            try:
+                # tol can be pretty high- we're looking for a large diff due to garbaged workspace
+                check_consistency([sym, sym_no_cudnn], [ctx, ctx], tol=1e-2)
+            except:
+                print('Failing conv size = {}'.format(size))
+                raise

 @with_seed()
 def test_convolution_multiple_streams():

From 838bda08a30b567dbd7eb2f7866bcea39c17bc75 Mon Sep 17 00:00:00 2001
From: Dick Carter
Date: Thu, 31 Jan 2019 14:46:52 -0800
Subject: [PATCH 06/10] Add description of proper aux stream use using events.
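
The protocol being described can be shown as a self-contained CUDA sketch
(illustration only, not MXNet code; the kernel and variable names are invented):

  // two_stream_sync.cu: one event gates the aux stream on primary-stream
  // inputs; a second event gates primary-stream consumers on aux results.
  #include <cuda_runtime.h>
  #include <cstdio>

  __global__ void produce(float *x)  { x[threadIdx.x] = 1.0f; }   // primary stream
  __global__ void aux_work(float *x) { x[threadIdx.x] += 1.0f; }  // aux stream
  __global__ void consume(float *x)  { x[threadIdx.x] *= 2.0f; }  // primary stream

  int main() {
    float *x;
    cudaMalloc(&x, 32 * sizeof(float));
    cudaStream_t primary, aux;
    cudaStreamCreate(&primary);
    cudaStreamCreate(&aux);
    cudaEvent_t inputs_ready, aux_done;
    cudaEventCreateWithFlags(&inputs_ready, cudaEventDisableTiming);
    cudaEventCreateWithFlags(&aux_done, cudaEventDisableTiming);

    produce<<<1, 32, 0, primary>>>(x);
    cudaEventRecord(inputs_ready, primary);     // inputs now exist ...
    cudaStreamWaitEvent(aux, inputs_ready, 0);  // ... so aux work may begin
    aux_work<<<1, 32, 0, aux>>>(x);
    cudaEventRecord(aux_done, aux);             // aux outputs complete ...
    cudaStreamWaitEvent(primary, aux_done, 0);  // ... so consumers may proceed
    consume<<<1, 32, 0, primary>>>(x);

    cudaStreamSynchronize(primary);
    std::printf("synced\n");
    return 0;
  }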
---
 include/mxnet/base.h | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/include/mxnet/base.h b/include/mxnet/base.h
index 7e3ac5b9ec72..c07e43aa408b 100644
--- a/include/mxnet/base.h
+++ b/include/mxnet/base.h
@@ -255,8 +255,14 @@ struct RunContext {
     return static_cast<mshadow::Stream<xpu>*>(stream);
   }
   /*!
-   * \brief get the auxiliary (i.e. 2nd) mshadow stream from Context
-   * The user must sync work enqueued to this stream with the primary stream using events.
+   * \brief get the auxiliary (i.e. 2nd) mshadow stream from the Context
+   * The user must sync work enqueued to the aux stream with the primary stream using events.
+   * For example, gpu kernels that produce the inputs to the operation are likely to be
+   * launched into the primary stream.  An event must be used to synchronize a kernel launched
+   * into the auxiliary stream on the availability of these inputs.  Also, a second event must be
+   * used to synchronize auxiliary stream kernel outputs with the primary stream, so that a
+   * subsequent operation's kernels launched into the primary stream read their inputs only when
+   * they are ready.  See ./src/operator/nn/cudnn/cudnn_convolution-inl.h for an example.
    * \return the mshadow stream
    * \tparam xpu the device type of the stream
    */

From c3f332055f129a0f7f946adf8f4d21eeb7642e91 Mon Sep 17 00:00:00 2001
From: Dick Carter
Date: Fri, 8 Feb 2019 19:23:33 -0800
Subject: [PATCH 07/10] RAII rework to simplify usage within operators.

---
 include/mxnet/base.h                          | 124 ++++++++++++++++--
 include/mxnet/op_attr_types.h                 |  11 +-
 src/engine/naive_engine.cc                    |  13 +-
 src/engine/stream_manager.h                   |  21 +--
 src/engine/threaded_engine_perdevice.cc       |   8 +-
 src/operator/nn/cudnn/cudnn_convolution-inl.h |  28 +---
 6 files changed, 150 insertions(+), 55 deletions(-)

diff --git a/include/mxnet/base.h b/include/mxnet/base.h
index c07e43aa408b..908e829742fa 100644
--- a/include/mxnet/base.h
+++ b/include/mxnet/base.h
@@ -226,6 +226,111 @@ struct Context {
   inline static Context FromString(const std::string& str);
 };

+#if MXNET_USE_CUDA
+/*! \brief Holds an auxiliary mshadow gpu stream that can be synced with a primary stream. */
+class GPUAuxStream {
+public:
+  /*!
+   * \brief constructor.
+   * \param primary_stream gpu stream that is synced with the created auxiliary stream.
+   */
+  explicit GPUAuxStream(mshadow::Stream<gpu> *primary_stream) :
+    primary_stream_(primary_stream),
+    aux_stream_(primary_stream),
+    gpu_stream_sync_event_(nullptr) {
+    if (Context::GetGPUStreamsPerWorker() >= 2) {
+      // Create auxiliary stream on the same device with the same properties as the primary stream
+      bool primary_has_blas_handle =
+        primary_stream->blas_handle_ownership_ == mshadow::Stream<gpu>::OwnHandle;
+      bool primary_has_dnn_handle =
+        primary_stream->dnn_handle_ownership_ == mshadow::Stream<gpu>::OwnHandle;
+      aux_stream_ = mshadow::NewStream<gpu>(primary_has_blas_handle,
+                                            primary_has_dnn_handle,
+                                            primary_stream->dev_id);
+      MSHADOW_CUDA_CALL(cudaEventCreateWithFlags(&gpu_stream_sync_event_, cudaEventDisableTiming));
+    }
+  }
+  /*! \brief destructor */
+  ~GPUAuxStream() {
+    // If the aux_stream_ == primary_stream_, then we created no new streams to destroy.
+    if (aux_stream_ != primary_stream_) {
+      MSHADOW_CATCH_ERROR(mshadow::DeleteStream(aux_stream_));
+      MSHADOW_CATCH_ERROR(cudaEventDestroy(gpu_stream_sync_event_));
+    }
+  }
+  /*!
+   * \brief Makes future aux stream work wait on the completion of existing primary stream work.
+   */
+  void PreAuxStreamUseSync() {
+    // If the aux_stream_ == primary_stream_, then no synchronization is necessary.
+    if (aux_stream_ != primary_stream_)
+      StreamSync(primary_stream_, aux_stream_, gpu_stream_sync_event_);
+  }
+  /*!
+   * \brief Makes future primary stream work wait on the completion of existing aux stream work.
+   */
+  void PostAuxStreamUseSync() {
+    // If the aux_stream_ == primary_stream_, then no synchronization is necessary.
+    if (aux_stream_ != primary_stream_)
+      StreamSync(aux_stream_, primary_stream_, gpu_stream_sync_event_);
+  }
+  /*! \brief Getter for created auxiliary stream. */
+  mshadow::Stream<gpu> *GetStream() { return aux_stream_; }
+  /*!
+   * \brief Make future work enqueued to `s2` wait on completion of current work enqueued to `s1`.
+   * \param s1 stream with work that must be completed before future s2 work can begin.
+   * \param s2 stream whose future work is made to wait on the completion of existing s1 work.
+   * \param event used to pass s1 state to s2.
+   */
+  static void StreamSync(mshadow::Stream<gpu> *s1, mshadow::Stream<gpu> *s2, cudaEvent_t event) {
+    MSHADOW_CUDA_CALL(cudaEventRecord(event, s1->stream_));
+    MSHADOW_CUDA_CALL(cudaStreamWaitEvent(s2->stream_, event, 0));
+  }
+
+private:
+  mshadow::Stream<gpu> *primary_stream_;
+  mshadow::Stream<gpu> *aux_stream_;
+  cudaEvent_t gpu_stream_sync_event_;
+};
+
+/*!
+ * \brief Provides automatic coordination of an auxiliary stream with a primary one.
+ * This object, upon construction, prepares an aux stream for use by syncing it with enqueued
+ * primary-stream work.  Object destruction will sync again so future primary-stream work
+ * will wait on enqueued aux-stream work.  If MXNET_GPU_WORKER_NSTREAMS == 1, then this defaults
+ * simply: the primary stream will equal the aux stream and the syncs will be executed as nops.
+ * See ./src/operator/cudnn/cudnn_convolution-inl.h for a usage example.
+ */
+class SyncedGPUAuxStream {
+public:
+  /*!
+   * \brief constructor.
+   * \param gpu_aux_stream auxiliary gpu stream that is managed by this RAII object.
+   */
+  SyncedGPUAuxStream(GPUAuxStream *gpu_aux_stream) : gpu_aux_stream_(gpu_aux_stream) {
+    gpu_aux_stream_->PreAuxStreamUseSync();
+  }
+  /*! \brief destructor */
+  ~SyncedGPUAuxStream() {
+    gpu_aux_stream_->PostAuxStreamUseSync();
+  }
+  /*! \brief copy constructor deleted to prevent unexpected synchronizations. */
+  SyncedGPUAuxStream(const SyncedGPUAuxStream&) = delete;
+  /*! \brief copy assignment operator deleted to prevent unexpected synchronizations. */
+  void operator=(const SyncedGPUAuxStream&) = delete;
+  /*! \brief move constructor permitted as alternative to copying. */
+  SyncedGPUAuxStream(SyncedGPUAuxStream&&) = default;
+  /*! \brief move assignment operator permitted as alternative to copy assignment. */
+  SyncedGPUAuxStream& operator=(SyncedGPUAuxStream&&) = default;
+  /*! \brief Getter for underlying mshadow::Stream<gpu>. */
+  inline mshadow::Stream<gpu>* GetStream() const {
+    return gpu_aux_stream_->GetStream();
+  }
+private:
+  GPUAuxStream *gpu_aux_stream_;
+};
+#endif  // MXNET_USE_CUDA
+
 /*!
  * \brief execution time context.
  * The information needed in runtime for actual execution.
@@ -254,22 +359,15 @@ struct RunContext {
   inline mshadow::Stream<xpu>* get_stream() const {
     return static_cast<mshadow::Stream<xpu>*>(stream);
   }
+#if MXNET_USE_CUDA
   /*!
-   * \brief get the auxiliary (i.e. 2nd) mshadow stream from the Context
-   * The user must sync work enqueued to the aux stream with the primary stream using events.
-   * For example, gpu kernels that produce the inputs to the operation are likely to be
-   * launched into the primary stream.  An event must be used to synchronize a kernel launched
-   * into the auxiliary stream on the availability of these inputs.  Also, a second event must be
-   * used to synchronize auxiliary stream kernel outputs with the primary stream, so that a
-   * subsequent operation's kernels launched into the primary stream read their inputs only when
-   * they are ready.  See ./src/operator/nn/cudnn/cudnn_convolution-inl.h for an example.
-   * \return the mshadow stream
-   * \tparam xpu the device type of the stream
+   * \brief get an RAII object that transparently handles the syncing of the auxiliary stream.
+   * \return the aux stream auto-syncing object
    */
-  template<typename xpu>
-  inline mshadow::Stream<xpu>* get_aux_stream() const {
-    return static_cast<mshadow::Stream<xpu>*>(aux_stream);
+  inline SyncedGPUAuxStream get_gpu_aux_stream() const {
+    return SyncedGPUAuxStream(static_cast<GPUAuxStream*>(aux_stream));
   }
+#endif
   /*! \brief get the base Context from RunContext */
   inline const Context& get_ctx() const {
     return ctx;
   }
diff --git a/include/mxnet/op_attr_types.h b/include/mxnet/op_attr_types.h
index 25b95445213e..22bba301221d 100644
--- a/include/mxnet/op_attr_types.h
+++ b/include/mxnet/op_attr_types.h
@@ -83,10 +83,15 @@ struct OpContext {
   inline mshadow::Stream<xpu>* get_stream() const {
     return run_ctx.get_stream<xpu>();
   }
-  template<typename xpu>
-  inline mshadow::Stream<xpu>* get_aux_stream() const {
-    return run_ctx.get_aux_stream<xpu>();
+#if MXNET_USE_CUDA
+  /*!
+   * \brief get auxiliary gpu stream auto-syncing object from Context
+   * \return the aux stream auto-syncing object
+   */
+  inline SyncedGPUAuxStream get_gpu_aux_stream() const {
+    return run_ctx.get_gpu_aux_stream();
   }
+#endif
 };

 /*! \brief the execution type of the operator */
diff --git a/src/engine/naive_engine.cc b/src/engine/naive_engine.cc
index 75fcfa91cef1..8ce0b973f007 100644
--- a/src/engine/naive_engine.cc
+++ b/src/engine/naive_engine.cc
@@ -74,6 +74,12 @@ class NaiveEngine final : public Engine {
         streams_[i] = nullptr;
       }
     }
+    for (size_t i = 0; i < aux_streams_.size(); ++i) {
+      if (aux_streams_[i] != nullptr) {
+        delete aux_streams_[i];
+        aux_streams_[i] = nullptr;
+      }
+    }
 #endif
   }

@@ -172,8 +178,9 @@ class NaiveEngine final : public Engine {
       }
       if (streams_[dev_id] == nullptr) {
         streams_[dev_id] = mshadow::NewStream<gpu>(true, MXNET_USE_CUDNN != 0, dev_id);
+        aux_streams_[dev_id] = new GPUAuxStream(streams_[dev_id]);
       }
-      exec_fun(RunContext{exec_ctx, streams_[dev_id], streams_[dev_id], false}, callback);
+      exec_fun(RunContext{exec_ctx, streams_[dev_id], aux_streams_[dev_id], false}, callback);
 #else
       LOG(FATAL) << "GPU is not enabled";
 #endif
@@ -220,6 +227,10 @@ class NaiveEngine final : public Engine {
   mshadow::Stream<cpu> cpu_stream_;
   // GPU streams
   std::vector<mshadow::Stream<gpu>*> streams_;
+#if MXNET_USE_CUDA
+  // GPU auxiliary streams
+  std::vector<GPUAuxStream*> aux_streams_;
+#endif
 };  // class NaiveEngine

 Engine *CreateNaiveEngine() {
diff --git a/src/engine/stream_manager.h b/src/engine/stream_manager.h
index 56a5260a4f76..f1258f8eaafc 100644
--- a/src/engine/stream_manager.h
+++ b/src/engine/stream_manager.h
@@ -55,7 +55,7 @@ class StreamManager {
 #if MXNET_USE_CUDA
   std::array<std::array<mshadow::Stream<gpu>*, kStreams>, kNumGpus>
       gpu_streams_;
-  std::array<std::array<mshadow::Stream<gpu>*, kStreams>, kNumGpus>
+  std::array<std::array<GPUAuxStream*, kStreams>, kNumGpus>
      gpu_aux_streams_;
   std::array<mshadow::Stream<gpu>*, kNumGpus> gpu_io_streams_;
   std::array<int, kNumGpus> gpu_cnt_;
 #endif  // MXNET_USE_CUDA
@@ -79,12 +79,12 @@ RunContext StreamManager<kNumGpus, kStreams>::GetRunContext(
       auto&& counter = gpu_cnt_.at(ctx.dev_id);
       if (counter == -1) {
         mxnet::common::cuda::DeviceStore device_store(ctx.dev_id);
-        for (auto&& i : gpu_streams_.at(ctx.dev_id)) {
-          i = mshadow::NewStream<gpu>(true, MXNET_USE_CUDNN != 0, ctx.dev_id);
-        }
-        for (auto&& i : gpu_aux_streams_.at(ctx.dev_id)) {
-          i = Context::GetGPUStreamsPerWorker() >= 2 ?
-              mshadow::NewStream<gpu>(true, MXNET_USE_CUDNN != 0, ctx.dev_id) : nullptr;
+        auto dev_streams = gpu_streams_.at(ctx.dev_id);
+        auto dev_aux_streams = gpu_aux_streams_.at(ctx.dev_id);
+        for (int i = 0; i != dev_streams.size(); ++i) {
+          auto primary_stream = mshadow::NewStream<gpu>(true, MXNET_USE_CUDNN != 0, ctx.dev_id);
+          dev_streams.at(i) = primary_stream;
+          dev_aux_streams.at(i) = new GPUAuxStream(primary_stream);
         }
         counter = 0;
       }
@@ -152,9 +152,12 @@ void StreamManager<kNumGpus, kStreams>::Finalize() {
 #if MXNET_USE_CUDA
   for (std::size_t i = 0; i < kNumGpus; ++i) {
     if (gpu_cnt_.at(i) != -1) {
-      for (auto&& j : gpu_streams_.at(i)) {
+      for (auto&& primary_stream : gpu_streams_.at(i)) {
         // Catch exception for CUDA driver shutdown
-        MSHADOW_CATCH_ERROR(mshadow::DeleteStream(j));
+        MSHADOW_CATCH_ERROR(mshadow::DeleteStream(primary_stream));
+      }
+      for (auto&& aux_stream : gpu_aux_streams_.at(i)) {
+        delete aux_stream;
       }
       gpu_cnt_.at(i) = -1;
     }
diff --git a/src/engine/threaded_engine_perdevice.cc b/src/engine/threaded_engine_perdevice.cc
index a9536ab55a12..bcb101e9e1bb 100644
--- a/src/engine/threaded_engine_perdevice.cc
+++ b/src/engine/threaded_engine_perdevice.cc
@@ -245,7 +245,7 @@ class ThreadedEnginePerDevice : public ThreadedEngine {
 #if MXNET_USE_CUDA
     CHECK(block != nullptr);
     mshadow::Stream<gpu> *stream = nullptr;
-    mshadow::Stream<gpu> *aux_stream = nullptr;
+    GPUAuxStream *aux_stream = nullptr;
     do {
       ThreadPool::SetReadyOnDestroy setReady(ready_event);
       // allocate stream
       if (ctx.dev_mask() == Context::kCPUPinned) {
         stream = mshadow::NewStream<gpu>(false, false, ctx.dev_id);
       } else {
         stream = mshadow::NewStream<gpu>(true, MXNET_USE_CUDNN != 0, ctx.dev_id);
-        if (Context::GetGPUStreamsPerWorker() >= 2) {
-          aux_stream = mshadow::NewStream<gpu>(true, MXNET_USE_CUDNN != 0, ctx.dev_id);
-        }
+        aux_stream = new GPUAuxStream(stream);
       }
     } while (false);
     // execute task
@@ -273,6 +271,8 @@ class ThreadedEnginePerDevice : public ThreadedEngine {
       }
       // Catch exception for CUDA driver shutdown
       MSHADOW_CATCH_ERROR(mshadow::DeleteStream(stream));
+      if (aux_stream != nullptr)
+        delete aux_stream;
 #else
       ready_event->signal();
 #endif
diff --git a/src/operator/nn/cudnn/cudnn_convolution-inl.h b/src/operator/nn/cudnn/cudnn_convolution-inl.h
index 533e47f2458a..f68d2e3e8ead 100644
--- a/src/operator/nn/cudnn/cudnn_convolution-inl.h
+++ b/src/operator/nn/cudnn/cudnn_convolution-inl.h
@@ -54,10 +54,6 @@ class CuDNNConvolutionOp {
     CUDNN_CALL(cudnnCreateConvolutionDescriptor(&back_conv_desc_));
     CUDNN_CALL(cudnnCreateConvolutionDescriptor(&back_conv_desc_w_));
     parallelize_backward_kernels_ = Context::GetGPUStreamsPerWorker() >= 2;
-    if (parallelize_backward_kernels_) {
-      CUDA_CALL(cudaEventCreateWithFlags(&dgrad_can_start_, cudaEventDisableTiming));
-      CUDA_CALL(cudaEventCreateWithFlags(&dgrad_completion_, cudaEventDisableTiming));
-    }
   }

   void Init(const ConvolutionParam& param,
@@ -126,10 +122,6 @@ class CuDNNConvolutionOp {
     CUDNN_CALL(cudnnDestroyConvolutionDescriptor(forward_conv_desc_));
     CUDNN_CALL(cudnnDestroyConvolutionDescriptor(back_conv_desc_));
     CUDNN_CALL(cudnnDestroyConvolutionDescriptor(back_conv_desc_w_));
-    if (parallelize_backward_kernels_) {
-      CUDA_CALL(cudaEventDestroy(dgrad_can_start_));
-      CUDA_CALL(cudaEventDestroy(dgrad_completion_));
-    }
   }

   void Forward(const OpContext &ctx,
@@ -233,14 +225,8 @@ class CuDNNConvolutionOp {
     CHECK_EQ(in_data.size(), expected);
     CHECK_EQ(in_grad.size(), expected);
     Stream<gpu> *s = ctx.get_stream<gpu>();
-    Stream<gpu> *s_dgrad = parallelize_backward_kernels_ ? ctx.get_aux_stream<gpu>() : s;
-
-    // Make sure the dgrad kernel in the aux stream doesn't start before it would have
-    // had it been launched into the operator's primary stream.
-    if (parallelize_backward_kernels_ && req[conv::kData] != kNullOp) {
-      CUDA_CALL(cudaEventRecord(dgrad_can_start_, s->stream_));
-      CUDA_CALL(cudaStreamWaitEvent(s_dgrad->stream_, dgrad_can_start_, 0));
-    }
+    // RAII object to handle syncing of the underlying auxiliary stream with the primary stream
+    SyncedGPUAuxStream s_dgrad = ctx.get_gpu_aux_stream();

     // I/O's should have 2 more dims than the kernel dim
     DType *grad_ptr = GetNdPtr(out_grad[conv::kOut], param_.kernel.ndim() + 2, s);
@@ -301,7 +287,7 @@ class CuDNNConvolutionOp {
                gwmat_ptr));
     }
     if (req[conv::kData] != kNullOp) {
-      CUDNN_CALL(cudnnConvolutionBackwardData(s_dgrad->dnn_handle_,
+      CUDNN_CALL(cudnnConvolutionBackwardData(s_dgrad.GetStream()->dnn_handle_,
                &alpha,
                filter_desc_,
                wmat_ptr,
@@ -314,10 +300,6 @@ class CuDNNConvolutionOp {
                req[conv::kData] == kAddTo? &beta_add : &beta,
                in_desc_,
                gdata_ptr));
-      if (parallelize_backward_kernels_) {
-        CUDA_CALL(cudaEventRecord(dgrad_completion_, s_dgrad->stream_));
-        CUDA_CALL(cudaStreamWaitEvent(s->stream_, dgrad_completion_, 0));
-      }
     }
 #else
     for (uint32_t g = 0; g < param_.num_group; ++g) {
@@ -1108,10 +1090,6 @@ class CuDNNConvolutionOp {
   cudnnConvolutionDescriptor_t back_conv_desc_w_;
   // Should dgrad and wgrad be launched into separate streams
   bool parallelize_backward_kernels_;
-  // Event to signal dgrad kernel aux stream completion back to the main stream of this operator.
-  cudaEvent_t dgrad_completion_;
-  // Event from the main stream of this operator that the dgrad kernel can begin in the aux stream.
-  cudaEvent_t dgrad_can_start_;
   // Algorithm for the forward inference operation
   CuDNNAlgo<cudnnConvolutionFwdAlgo_t> forward_algo_;
   // Algorithm for the back-prop operation to the data

From db5075f4aa95c201c6238b75f8b1c1e82aa4d3d4 Mon Sep 17 00:00:00 2001
From: Dick Carter
Date: Sat, 9 Feb 2019 17:47:27 -0800
Subject: [PATCH 08/10] Fix cpplint.

---
 include/mxnet/base.h | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/include/mxnet/base.h b/include/mxnet/base.h
index 908e829742fa..16407963ff7c 100644
--- a/include/mxnet/base.h
+++ b/include/mxnet/base.h
@@ -229,7 +229,7 @@ struct Context {
 #if MXNET_USE_CUDA
 /*! \brief Holds an auxiliary mshadow gpu stream that can be synced with a primary stream. */
 class GPUAuxStream {
-public:
+ public:
   /*!
    * \brief constructor.
    * \param primary_stream gpu stream that is synced with the created auxiliary stream.
@@ -287,7 +287,7 @@ class GPUAuxStream {
     MSHADOW_CUDA_CALL(cudaStreamWaitEvent(s2->stream_, event, 0));
   }

-private:
+ private:
   mshadow::Stream<gpu> *primary_stream_;
   mshadow::Stream<gpu> *aux_stream_;
   cudaEvent_t gpu_stream_sync_event_;
 };
@@ -302,12 +302,12 @@ class GPUAuxStream {
  * See ./src/operator/cudnn/cudnn_convolution-inl.h for a usage example.
  */
 class SyncedGPUAuxStream {
-public:
+ public:
   /*!
    * \brief constructor.
    * \param gpu_aux_stream auxiliary gpu stream that is managed by this RAII object.
    */
-  SyncedGPUAuxStream(GPUAuxStream *gpu_aux_stream) : gpu_aux_stream_(gpu_aux_stream) {
+  explicit SyncedGPUAuxStream(GPUAuxStream *gpu_aux_stream) : gpu_aux_stream_(gpu_aux_stream) {
     gpu_aux_stream_->PreAuxStreamUseSync();
   }
   /*! \brief destructor */
   ~SyncedGPUAuxStream() {
     gpu_aux_stream_->PostAuxStreamUseSync();
   }
@@ -326,7 +326,8 @@ class SyncedGPUAuxStream {
   inline mshadow::Stream<gpu>* GetStream() const {
     return gpu_aux_stream_->GetStream();
   }
-private:
+
+ private:
   GPUAuxStream *gpu_aux_stream_;
 };
 #endif  // MXNET_USE_CUDA

From ea56c9105952d6e992ba539b3f6d7e414522a616 Mon Sep 17 00:00:00 2001
From: Dick Carter
Date: Sun, 10 Feb 2019 18:49:47 -0800
Subject: [PATCH 09/10] Expand testing to cover all engines.

---
 src/engine/naive_engine.cc            |  1 +
 src/engine/stream_manager.h           | 13 +++++++------
 tests/python/gpu/test_operator_gpu.py | 13 ++++++++-----
 3 files changed, 16 insertions(+), 11 deletions(-)

diff --git a/src/engine/naive_engine.cc b/src/engine/naive_engine.cc
index 8ce0b973f007..61ead23cf526 100644
--- a/src/engine/naive_engine.cc
+++ b/src/engine/naive_engine.cc
@@ -175,6 +175,7 @@ class NaiveEngine final : public Engine {
       MSHADOW_CATCH_ERROR(mshadow::SetDevice<gpu>(exec_ctx.dev_id));
       if (streams_.size() <= dev_id) {
         streams_.resize(dev_id + 1, nullptr);
+        aux_streams_.resize(dev_id + 1, nullptr);
       }
       if (streams_[dev_id] == nullptr) {
         streams_[dev_id] = mshadow::NewStream<gpu>(true, MXNET_USE_CUDNN != 0, dev_id);
diff --git a/src/engine/stream_manager.h b/src/engine/stream_manager.h
index f1258f8eaafc..42d03e55a275 100644
--- a/src/engine/stream_manager.h
+++ b/src/engine/stream_manager.h
@@ -79,12 +79,13 @@ RunContext StreamManager<kNumGpus, kStreams>::GetRunContext(
       auto&& counter = gpu_cnt_.at(ctx.dev_id);
       if (counter == -1) {
         mxnet::common::cuda::DeviceStore device_store(ctx.dev_id);
-        auto dev_streams = gpu_streams_.at(ctx.dev_id);
-        auto dev_aux_streams = gpu_aux_streams_.at(ctx.dev_id);
-        for (int i = 0; i != dev_streams.size(); ++i) {
-          auto primary_stream = mshadow::NewStream<gpu>(true, MXNET_USE_CUDNN != 0, ctx.dev_id);
-          dev_streams.at(i) = primary_stream;
-          dev_aux_streams.at(i) = new GPUAuxStream(primary_stream);
+        for (auto&& primary_stream : gpu_streams_.at(ctx.dev_id)) {
+          primary_stream = mshadow::NewStream<gpu>(true, MXNET_USE_CUDNN != 0, ctx.dev_id);
+        }
+        int idx = 0;
+        for (auto&& aux_stream : gpu_aux_streams_.at(ctx.dev_id)) {
+          auto primary_stream = gpu_streams_.at(ctx.dev_id).at(idx++);
+          aux_stream = new GPUAuxStream(primary_stream);
         }
         counter = 0;
       }
diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py
index 237df55652d3..b84879fa5b5f 100644
--- a/tests/python/gpu/test_operator_gpu.py
+++ b/tests/python/gpu/test_operator_gpu.py
@@ -523,7 +523,7 @@ def test_convolution_options():

 # Helper function to run tests in a subprocess to avoid save/restore of os.environ.
 # Also avoids issues of cached environment variable lookups in the backend.
-def _test_in_separate_process(func, *args):
+def _test_in_separate_process(func, env, *args):
     try:
         mpctx = mp.get_context('spawn')
     except:
         print('SKIP: python%s.%s lacks the required process fork-exec support ... ' %
               sys.version_info[0:2], file=sys.stderr, end='')
     else:
         seed = np.random.randint(0,1024*1024*1024)
+        for (key, value) in env.items():
+            os.environ[key] = str(value)
         # Prepend seed as first arg
         p = mpctx.Process(target=func, args=(seed,)+args)
         p.start()
         p.join()
         assert p.exitcode == 0, "Non-zero exit code %d from %s()." % (p.exitcode, func.__name__)

-def _conv_with_num_streams(seed, num_streams):
-    os.environ['MXNET_GPU_WORKER_NSTREAMS'] = str(num_streams)
+def _conv_with_num_streams(seed):
     with random_seed(seed):
         # Try to expose timing-dependent improper workspace sharing by parallel dgrad and wgrad
         num_trials = 20
@@ -563,8 +564,10 @@ def _conv_with_num_streams(seed, num_streams):

 @with_seed()
 def test_convolution_multiple_streams():
-    _test_in_separate_process(_conv_with_num_streams, 1)
-    _test_in_separate_process(_conv_with_num_streams, 2)
+    for num_streams in [1, 2]:
+        for engine in ['NaiveEngine', 'ThreadedEngine', 'ThreadedEnginePerDevice']:
+            _test_in_separate_process(_conv_with_num_streams,
+                {'MXNET_GPU_WORKER_NSTREAMS' : num_streams, 'MXNET_ENGINE_TYPE' : engine})


 # This test is designed to expose an issue with cudnn v7.1.4 algo find() when invoked with large c.
 @with_seed()

From 790a9981440424b6646a7f83f73cce7043ab2398 Mon Sep 17 00:00:00 2001
From: Dick Carter
Date: Wed, 13 Feb 2019 19:32:12 -0800
Subject: [PATCH 10/10] Fix NaiveEngine shutdown segfault on CentOS7.

---
 src/engine/naive_engine.cc | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/src/engine/naive_engine.cc b/src/engine/naive_engine.cc
index 61ead23cf526..db4491981bdd 100644
--- a/src/engine/naive_engine.cc
+++ b/src/engine/naive_engine.cc
@@ -62,6 +62,8 @@ class NaiveEngine final : public Engine {
   };

   NaiveEngine() {
+    objpool_opr_ref_ = common::ObjectPool<NaiveOpr>::_GetSharedRef();
+    objpool_var_ref_ = common::ObjectPool<NaiveVar>::_GetSharedRef();
   }
   // virtual destructor
   virtual ~NaiveEngine() {
@@ -232,6 +234,13 @@ class NaiveEngine final : public Engine {
   // GPU auxiliary streams
   std::vector<GPUAuxStream*> aux_streams_;
 #endif
+/*!
+ * \brief Holding a shared_ptr to the object pool to prevent it from being destructed too early
+ * See also #309 (https://github.com/dmlc/mxnet/issues/309) and similar fix in threaded_engine.h.
+ * Without this, segfaults seen on CentOS7 in test_operator_gpu.py:test_convolution_multiple_streams
+ */
+  std::shared_ptr<common::ObjectPool<NaiveOpr> > objpool_opr_ref_;
+  std::shared_ptr<common::ObjectPool<NaiveVar> > objpool_var_ref_;
 };  // class NaiveEngine

 Engine *CreateNaiveEngine() {
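
The fix in the last patch relies on a standard keep-alive idiom: a consumer
stores a shared_ptr to a shared-singleton resource so the resource cannot be
torn down before the consumer during static destruction. A standalone C++
sketch of the idiom (hypothetical names, not MXNet code):

  // keep_alive.cc: Engine holds a shared_ptr to Pool, so Pool must outlive it.
  #include <cstdio>
  #include <memory>

  struct Pool {
    ~Pool() { std::printf("Pool destroyed\n"); }
    static std::shared_ptr<Pool> GetSharedRef() {
      static std::shared_ptr<Pool> inst = std::make_shared<Pool>();
      return inst;
    }
  };

  struct Engine {
    Engine() : pool_ref_(Pool::GetSharedRef()) {}     // pin the pool's lifetime
    ~Engine() { std::printf("Engine destroyed\n"); }  // always runs before Pool's dtor
    std::shared_ptr<Pool> pool_ref_;                  // released only when Engine dies
  };

  int main() {
    static Engine engine;  // prints "Engine destroyed" then "Pool destroyed" at exit
    return 0;
  }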