From 5011145983a78a365fc915efbb9285802fb4e02b Mon Sep 17 00:00:00 2001 From: Dick Carter Date: Fri, 26 Oct 2018 00:03:41 -0700 Subject: [PATCH] CudnnFind() usage improvements (#12804) * Add mx.context.gpu_memory_info() to python api for flexible tests. * Add test_gluon_gpu.py:test_large_models to show cudnnFind headroom issue. * Output model sizes tried by test_gluon_gpu.py:test_large_models. * Fix perl interface to MXGetGPUMemoryInformation. * Increase difficulty of test_gluon_gpu.py:test_large_models. * Forgot a file in fix for perl. * Modify test to pass on no-cudnn CI runner. * Mutex algo reg updates, serialize cudnnFind calls. * Fix for cudnnFind memory headroom issue. * Fix cpplint. * Respond to reviewers comments. * Guard against improper MXNET_GPU_MEM_LARGE_ALLOC_ROUND_SIZE values. * Fix potentially unassigned var. --- CONTRIBUTORS.md | 1 + docs/faq/env_var.md | 4 + include/mxnet/base.h | 14 +- include/mxnet/c_api.h | 10 + perl-package/AI-MXNetCAPI/mxnet.i | 10 + python/mxnet/context.py | 24 + src/c_api/c_api.cc | 11 + src/operator/nn/cudnn/cudnn_algoreg-inl.h | 66 +-- src/operator/nn/cudnn/cudnn_convolution-inl.h | 480 +++++++++-------- .../nn/cudnn/cudnn_deconvolution-inl.h | 502 ++++++++++-------- src/storage/pooled_storage_manager.h | 30 +- tests/python/gpu/test_gluon_gpu.py | 46 +- 12 files changed, 707 insertions(+), 491 deletions(-) diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index 5e5e76d52aae..404f135cd91e 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -187,3 +187,4 @@ List of Contributors * [LuckyPigeon](https://github.com/LuckyPigeon) * [Anton Chernov](https://github.com/lebeg) * [Denisa Roberts](https://github.com/D-Roberts) +* [Dick Carter](https://github.com/DickJC123) diff --git a/docs/faq/env_var.md b/docs/faq/env_var.md index 0464c73fadcb..92cf4931b04b 100644 --- a/docs/faq/env_var.md +++ b/docs/faq/env_var.md @@ -67,6 +67,10 @@ $env:MXNET_STORAGE_FALLBACK_LOG_VERBOSE=0 * MXNET_GPU_MEM_POOL_ROUND_LINEAR_CUTOFF - Values: Int ```(default=24)``` - The cutoff threshold that decides the rounding strategy. Let's denote the threshold as T. If the memory size is smaller than `2 ** T` (by default, it's 2 ** 24 = 16MB), it rounds to the smallest `2 ** n` that is larger than the requested memory size; if the memory size is larger than `2 ** T`, it rounds to the next k * 2 ** T. +* MXNET_GPU_MEM_LARGE_ALLOC_ROUND_SIZE + - Values: Int ```(default=2097152)``` + - When using the naive pool type, memory allocations larger than this threshold are rounded up to a multiple of this value. + - The default was chosen to minimize global memory fragmentation within the GPU driver. Set this to 1 to disable. ## Engine Type diff --git a/include/mxnet/base.h b/include/mxnet/base.h index 783f74ab447a..5f16eb441868 100644 --- a/include/mxnet/base.h +++ b/include/mxnet/base.h @@ -225,11 +225,11 @@ struct Context { /*! * \brief get the free and total available memory on a GPU * \param dev the GPU number to query - * \param free_mem pointer to the integer holding free GPU memory - * \param total_mem pointer to the integer holding total GPU memory + * \param free_mem pointer to the uint64_t holding free GPU memory + * \param total_mem pointer to the uint64_t holding total GPU memory * \return No return value */ - inline static void GetGPUMemoryInformation(int dev, int *free, int *total); + inline static void GetGPUMemoryInformation(int dev, uint64_t *free, uint64_t *total); /*! * Create a pinned CPU context. * \param dev_id the device id for corresponding GPU.
@@ -334,8 +334,8 @@ inline int32_t Context::GetGPUCount() { #endif } -inline void Context::GetGPUMemoryInformation(int dev, int *free_mem, - int *total_mem) { +inline void Context::GetGPUMemoryInformation(int dev, uint64_t *free_mem, + uint64_t *total_mem) { #if MXNET_USE_CUDA size_t memF, memT; @@ -354,8 +354,8 @@ inline void Context::GetGPUMemoryInformation(int dev, int *free_mem, e = cudaSetDevice(curDevice); CHECK_EQ(e, cudaSuccess) << " CUDA: " << cudaGetErrorString(e); - *free_mem = static_cast<int>(memF); - *total_mem = static_cast<int>(memT); + *free_mem = static_cast<uint64_t>(memF); + *total_mem = static_cast<uint64_t>(memT); #else LOG(FATAL) diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index dc33c95437f9..e9f1e2d6cccc 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -441,6 +441,7 @@ MXNET_DLL int MXGetGPUCount(int* out); /*! * \brief get the free and total available memory on a GPU + * Note: Deprecated, use MXGetGPUMemoryInformation64 instead. * \param dev the GPU number to query * \param free_mem pointer to the integer holding free GPU memory * \param total_mem pointer to the integer holding total GPU memory @@ -448,6 +449,15 @@ MXNET_DLL int MXGetGPUCount(int* out); */ MXNET_DLL int MXGetGPUMemoryInformation(int dev, int *free_mem, int *total_mem); +/*! + * \brief get the free and total available memory on a GPU + * \param dev the GPU number to query + * \param free_mem pointer to the uint64_t holding free GPU memory + * \param total_mem pointer to the uint64_t holding total GPU memory + * \return 0 when success, -1 when failure happens + */ +MXNET_DLL int MXGetGPUMemoryInformation64(int dev, uint64_t *free_mem, uint64_t *total_mem); + /*! * \brief get the MXNet library version as an integer * \param pointer to the integer holding the version number diff --git a/perl-package/AI-MXNetCAPI/mxnet.i b/perl-package/AI-MXNetCAPI/mxnet.i index 38665748a0bf..b1907f5cd7ec 100644 --- a/perl-package/AI-MXNetCAPI/mxnet.i +++ b/perl-package/AI-MXNetCAPI/mxnet.i @@ -344,6 +344,7 @@ int MXGetGPUCount(int* out); /*! * \brief get the free and total available memory on a GPU + * Note: deprecated, use MXGetGPUMemoryInformation64(). * \param dev the GPU number to query * \param free_mem pointer to the integer holding free GPU memory * \param total_mem pointer to the integer holding total GPU memory @@ -351,6 +352,15 @@ int MXGetGPUCount(int* out); */ int MXGetGPUMemoryInformation(int dev, int *out, int *out); +/*! + * \brief get the free and total available memory on a GPU + * \param dev the GPU number to query + * \param free_mem pointer to the uint64_t holding free GPU memory + * \param total_mem pointer to the uint64_t holding total GPU memory + * \return 0 when success, -1 when failure happens + */ +int MXGetGPUMemoryInformation64(int dev, uint64_t *out, uint64_t *out); + //------------------------------------- // Part 1: NDArray creation and deletion diff --git a/python/mxnet/context.py b/python/mxnet/context.py index 61b70532dd74..15ea9905de03 100644 --- a/python/mxnet/context.py +++ b/python/mxnet/context.py @@ -258,6 +258,30 @@ def num_gpus(): check_call(_LIB.MXGetGPUCount(ctypes.byref(count))) return count.value +def gpu_memory_info(device_id=0): + """Query CUDA for the free and total bytes of GPU global memory. + + Parameters + ---------- + device_id : int, optional + The device id of the GPU device. + + Raises + ------ + Will raise an exception on any CUDA error. + + Returns + ------- + (free, total) : (int, int) + The number of free and total bytes of GPU global memory.
+ + """ + free = ctypes.c_uint64() + total = ctypes.c_uint64() + dev_id = ctypes.c_int(device_id) + check_call(_LIB.MXGetGPUMemoryInformation64(dev_id, ctypes.byref(free), ctypes.byref(total))) + return (free.value, total.value) + def current_context(): """Returns the current context. diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 56e318097a3c..80bd60538ff5 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -122,7 +122,18 @@ int MXGetGPUCount(int* out) { API_END(); } +// Deprecated: use MXGetGPUMemoryInformation64() instead. int MXGetGPUMemoryInformation(int dev, int *free_mem, int *total_mem) { + API_BEGIN(); + uint64_t free_mem64 = 0UL; + uint64_t total_mem64 = 0UL; + Context::GetGPUMemoryInformation(dev, &free_mem64, &total_mem64); + *free_mem = static_cast(free_mem64); + *total_mem = static_cast(total_mem64); + API_END(); +} + +int MXGetGPUMemoryInformation64(int dev, uint64_t *free_mem, uint64_t *total_mem) { API_BEGIN(); Context::GetGPUMemoryInformation(dev, free_mem, total_mem); API_END(); diff --git a/src/operator/nn/cudnn/cudnn_algoreg-inl.h b/src/operator/nn/cudnn/cudnn_algoreg-inl.h index 3b59fd1c3ced..21d3a30ba7cd 100644 --- a/src/operator/nn/cudnn/cudnn_algoreg-inl.h +++ b/src/operator/nn/cudnn/cudnn_algoreg-inl.h @@ -30,6 +30,8 @@ #include #include #include +#include +#include #include "../../../common/cuda_utils.h" #include "../convolution-inl.h" #include "../deconvolution-inl.h" @@ -65,7 +67,11 @@ class CuDNNAlgo { template class CuDNNAlgoReg { public: - bool Find(const ParamType ¶m, + using AlgoSetter_t = std::function *, + CuDNNAlgo *, + CuDNNAlgo *)>; + + void FindOrElseRegister(const ParamType ¶m, const std::vector &in_shape, const std::vector &out_shape, cudnnDataType_t cudnn_data_type, @@ -75,7 +81,8 @@ class CuDNNAlgoReg { bool add_to_weight, CuDNNAlgo *fwd, CuDNNAlgo *bwd, - CuDNNAlgo *flt) { + CuDNNAlgo *flt, + const AlgoSetter_t &algo_setter) { CHECK(in_shape.size() == 2 || in_shape.size() == 3); ParamKey key{param, in_shape[0], in_shape[1], out_shape[0], cudnn_data_type, cudnn_forward_compute_type, cudnn_backward_compute_type, sm_arch, add_to_weight}; @@ -85,45 +92,28 @@ class CuDNNAlgoReg { *fwd = i->second.fwd; *bwd = i->second.bwd; *flt = i->second.flt; - return true; - } - return false; - } - - void Register(const ParamType ¶m, - const std::vector &in_shape, - const std::vector &out_shape, - cudnnDataType_t cudnn_data_type, - cudnnDataType_t cudnn_forward_compute_type, - cudnnDataType_t cudnn_backward_compute_type, - int sm_arch, - bool add_to_weight, - const CuDNNAlgo &fwd, - const CuDNNAlgo &bwd, - const CuDNNAlgo &flt) { - CHECK(in_shape.size() == 2 || in_shape.size() == 3); - ParamKey key{param, in_shape[0], in_shape[1], out_shape[0], cudnn_data_type, - cudnn_forward_compute_type, cudnn_backward_compute_type, sm_arch, add_to_weight}; - std::lock_guard guard(lock_); - if (param.cudnn_tune.value() && reg_.size() % 50 == 0) { - LOG(INFO) << "Running performance tests to find the best convolution " - "algorithm, " - "this can take a while... (setting env variable " - "MXNET_CUDNN_AUTOTUNE_DEFAULT to 0 to disable)"; - if (reg_.size() >= 1000) { - // Many people are very concerned about this warning, so change the warning once. - if (!is_warning_autotune_) { - LOG(INFO) - << "If you see this message in the middle of training, you are " - "probably using bucketing. 
Consider setting env variable " - "MXNET_CUDNN_AUTOTUNE_DEFAULT to 0 to disable cudnn tuning."; - is_warning_autotune_ = true; + } else { + if (param.cudnn_tune.value() && reg_.size() % 50 == 0) { + LOG(INFO) << "Running performance tests to find the best convolution " + "algorithm, " + "this can take a while... (setting env variable " + "MXNET_CUDNN_AUTOTUNE_DEFAULT to 0 to disable)"; + if (reg_.size() >= 1000) { + // Many people are very concerned about this warning, so change the warning once. + if (!is_warning_autotune_) { + LOG(INFO) + << "If you see this message in the middle of training, you are " + "probably using bucketing. Consider setting env variable " + "MXNET_CUDNN_AUTOTUNE_DEFAULT to 0 to disable cudnn tuning."; + is_warning_autotune_ = true; + } } } + // Call provided function to determine the algos- likely uses cudnnFind() or cudnnGet() + algo_setter(fwd, bwd, flt); + // Save result so future lookups hit in this registry + reg_.insert(std::pair(key, CudnnAlgorithms{*fwd, *bwd, *flt})); } - reg_[key].fwd = fwd; - reg_[key].bwd = bwd; - reg_[key].flt = flt; } static CuDNNAlgoReg *Get(); diff --git a/src/operator/nn/cudnn/cudnn_convolution-inl.h b/src/operator/nn/cudnn/cudnn_convolution-inl.h index 53bd76c9c3e8..d63d46821edc 100644 --- a/src/operator/nn/cudnn/cudnn_convolution-inl.h +++ b/src/operator/nn/cudnn/cudnn_convolution-inl.h @@ -26,6 +26,7 @@ #ifndef MXNET_OPERATOR_NN_CUDNN_CUDNN_CONVOLUTION_INL_H_ #define MXNET_OPERATOR_NN_CUDNN_CUDNN_CONVOLUTION_INL_H_ +#include #include #include #include @@ -611,236 +612,265 @@ class CuDNNConvolutionOp { } } - void SelectAlgo(const RunContext& rctx, + void CuDNNAlgoSetter(const RunContext& rctx, const std::vector& in_shape, const std::vector& out_shape, cudnnDataType_t cudnn_forward_compute_type, - cudnnDataType_t cudnn_backward_compute_type) { - if (!CuDNNConvAlgoReg::Get()->Find(param_, in_shape, out_shape, dtype_, - cudnn_forward_compute_type, cudnn_backward_compute_type, - SMArch(rctx.ctx.dev_id), add_to_weight_, - &forward_algo_, &back_algo_, &back_algo_w_)) { - mshadow::Stream *s = rctx.get_stream(); - CHECK_EQ(s->dnn_handle_ownership_, mshadow::Stream::OwnHandle); - size_t workspace_byte = static_cast(param_.workspace * sizeof(DType)); - #if CUDNN_MAJOR >= 7 - // Starting with cuDNNv7, the algo number returned by *Get*() is not the entire - // story: the notion of whether the algo ran in Tensor Core mode is not known. - // Since we want to report the Tensor Core mode in the verbose output, we switch - // to using the new *Get*_v7() call. Since the function signature of *Get*_v7() matches - // that of *Find*(), we can unify the find-vs-get logic by using function pointers. - - // Forward Algorithm Find/Get() v7 - std::vector fwd_results(MaxForwardAlgos(s->dnn_handle_)); - int actual_fwd_algos = 0; - auto fwd_algo_discoverer = - param_.cudnn_tune.value() == conv::kOff ? 
cudnnGetConvolutionForwardAlgorithm_v7 - : cudnnFindConvolutionForwardAlgorithm; - CUDNN_CALL((*fwd_algo_discoverer)(s->dnn_handle_, - in_desc_, - filter_desc_, - forward_conv_desc_, - out_desc_, - fwd_results.size(), - &actual_fwd_algos, - fwd_results.data())); - fwd_results.resize(actual_fwd_algos); - AlgoFinalSelect(fwd_results, "forward", - workspace_byte, &forward_algo_); - - // Backprop-to-Filter Algorithm Find/Get() v7 - auto max_bwd_filt_algos = MaxBackwardFilterAlgos(s->dnn_handle_); - std::vector bwd_filt_results(max_bwd_filt_algos); - int actual_bwd_filter_algos = 0; - // In cudnn v7.1.4, find() returned wgrad algos that could fail for large c if we - // were summing into the output (i.e. beta != 0). Get() returned OK algos though. - auto bwd_filter_algo_discoverer = - param_.cudnn_tune.value() == conv::kOff ? cudnnGetConvolutionBackwardFilterAlgorithm_v7 - : cudnnFindConvolutionBackwardFilterAlgorithm; - CUDNN_CALL((*bwd_filter_algo_discoverer)(s->dnn_handle_, + cudnnDataType_t cudnn_backward_compute_type, + CuDNNAlgo *fwd, + CuDNNAlgo *bwd, + CuDNNAlgo *flt) { + // Not in algo registry, must determine via *Get*() or *Find*() + mshadow::Stream *s = rctx.get_stream(); + CHECK_EQ(s->dnn_handle_ownership_, mshadow::Stream::OwnHandle); + size_t workspace_byte = static_cast(param_.workspace * sizeof(DType)); +#if CUDNN_MAJOR >= 7 + // Starting with cuDNNv7, the algo number returned by *Get*() is not the entire + // story: the notion of whether the algo ran in Tensor Core mode is not known. + // Since we want to report the Tensor Core mode in the verbose output, we switch + // to using the new *Get*_v7() call. Since the function signature of *Get*_v7() matches + // that of *Find*(), we can unify the find-vs-get logic by using function pointers. + + // Forward Algorithm Find/Get() v7 + std::vector fwd_results(MaxForwardAlgos(s->dnn_handle_)); + int actual_fwd_algos = 0; + auto fwd_algo_discoverer = + param_.cudnn_tune.value() == conv::kOff ? cudnnGetConvolutionForwardAlgorithm_v7 + : cudnnFindConvolutionForwardAlgorithm; + CUDNN_CALL((*fwd_algo_discoverer)(s->dnn_handle_, + in_desc_, + filter_desc_, + forward_conv_desc_, + out_desc_, + fwd_results.size(), + &actual_fwd_algos, + fwd_results.data())); + fwd_results.resize(actual_fwd_algos); + AlgoFinalSelect(fwd_results, "forward", + workspace_byte, fwd); + + // Backprop-to-Filter Algorithm Find/Get() v7 + auto max_bwd_filt_algos = MaxBackwardFilterAlgos(s->dnn_handle_); + std::vector bwd_filt_results(max_bwd_filt_algos); + int actual_bwd_filter_algos = 0; + // In cudnn v7.1.4, find() returned wgrad algos that could fail for large c if we + // were summing into the output (i.e. beta != 0). Get() returned OK algos though. + auto bwd_filter_algo_discoverer = + param_.cudnn_tune.value() == conv::kOff ? cudnnGetConvolutionBackwardFilterAlgorithm_v7 + : cudnnFindConvolutionBackwardFilterAlgorithm; + CUDNN_CALL((*bwd_filter_algo_discoverer)(s->dnn_handle_, + in_desc_, + out_desc_, + back_conv_desc_w_, + filter_desc_, + bwd_filt_results.size(), + &actual_bwd_filter_algos, + bwd_filt_results.data())); + bwd_filt_results.resize(actual_bwd_filter_algos); + AlgoFinalSelect(bwd_filt_results, "backprop-to-filter", + workspace_byte, flt); + + // Backprop-to-Data Algorithm Find/Get() v7 + auto max_bwd_data_algos = MaxBackwardDataAlgos(s->dnn_handle_); + std::vector bwd_data_results(max_bwd_data_algos); + int actual_bwd_data_algos = 0; + auto bwd_data_algo_discoverer = + param_.cudnn_tune.value() == conv::kOff ? 
cudnnGetConvolutionBackwardDataAlgorithm_v7 + : cudnnFindConvolutionBackwardDataAlgorithm; + CUDNN_CALL((*bwd_data_algo_discoverer)(s->dnn_handle_, + filter_desc_, + out_desc_, + back_conv_desc_, + in_desc_, + bwd_data_results.size(), + &actual_bwd_data_algos, + bwd_data_results.data())); + bwd_data_results.resize(actual_bwd_data_algos); + AlgoFinalSelect(bwd_data_results, "backprop-to-data", + workspace_byte, bwd); +#else + // CUDNN_MAJOR < 7 + const int kMaxAlgos = 10; + int nalgo = kMaxAlgos; + int i = 0; + size_t min_memory_needs = 0; + // Forward Algorithm Find/Get, v6 and earlier + if (CUDNN_MAJOR == 6 && param_.layout.value() == mshadow::kNHWC) { + // In cuDNNv6, for kNHWC, only CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM is + // supported. Hard-coded this since the algo find() or get() throws an FPE. + fwd->Set(CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM, false); + } else if (!param_.cudnn_tune.value()) { + cudnnConvolutionFwdAlgo_t fastest_fwd_algo; + CUDNN_CALL(cudnnGetConvolutionForwardAlgorithm(s->dnn_handle_, in_desc_, - out_desc_, - back_conv_desc_w_, filter_desc_, - bwd_filt_results.size(), - &actual_bwd_filter_algos, - bwd_filt_results.data())); - bwd_filt_results.resize(actual_bwd_filter_algos); - AlgoFinalSelect(bwd_filt_results, "backprop-to-filter", - workspace_byte, &back_algo_w_); - - // Backprop-to-Data Algorithm Find/Get() v7 - auto max_bwd_data_algos = MaxBackwardDataAlgos(s->dnn_handle_); - std::vector bwd_data_results(max_bwd_data_algos); - int actual_bwd_data_algos = 0; - auto bwd_data_algo_discoverer = - param_.cudnn_tune.value() == conv::kOff ? cudnnGetConvolutionBackwardDataAlgorithm_v7 - : cudnnFindConvolutionBackwardDataAlgorithm; - CUDNN_CALL((*bwd_data_algo_discoverer)(s->dnn_handle_, - filter_desc_, - out_desc_, - back_conv_desc_, - in_desc_, - bwd_data_results.size(), - &actual_bwd_data_algos, - bwd_data_results.data())); - bwd_data_results.resize(actual_bwd_data_algos); - AlgoFinalSelect(bwd_data_results, "backprop-to-data", - workspace_byte, &back_algo_); - #else - // CUDNN_MAJOR < 7 - const int kMaxAlgos = 10; - int nalgo = kMaxAlgos; - int i = 0; - size_t min_memory_needs = 0; - // Forward Algorithm Find/Get, v6 and earlier - if (CUDNN_MAJOR == 6 && param_.layout.value() == mshadow::kNHWC) { - // In cuDNNv6, for kNHWC, only CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM is - // supported. Hard-coded this since the algo find() or get() throws an FPE. - forward_algo_.Set(CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM, false); - } else if (!param_.cudnn_tune.value()) { - cudnnConvolutionFwdAlgo_t fastest_fwd_algo; - CUDNN_CALL(cudnnGetConvolutionForwardAlgorithm(s->dnn_handle_, - in_desc_, - filter_desc_, - forward_conv_desc_, - out_desc_, - CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT, - workspace_byte, - &fastest_fwd_algo)); - forward_algo_.Set(fastest_fwd_algo, false); - } else { - cudnnConvolutionFwdAlgoPerf_t fwd_algo[kMaxAlgos]; - CUDNN_CALL(cudnnFindConvolutionForwardAlgorithm(s->dnn_handle_, - in_desc_, - filter_desc_, - forward_conv_desc_, - out_desc_, - kMaxAlgos, - &nalgo, - fwd_algo)); - i = 0; - while (i < nalgo - && (fwd_algo[i].status != CUDNN_STATUS_SUCCESS - || (param_.cudnn_tune.value() == conv::kLimited - && fwd_algo[i].memory > workspace_byte))) { - ++i; - min_memory_needs = - (i == 0) ? fwd_algo[i].memory : std::min(min_memory_needs, fwd_algo[i].memory); - } - if (i == nalgo) { - LOG(FATAL) << nalgo << " forward algorithms with minimum memory requirement " - << min_memory_needs << " bytes have been tried. 
Workspace size is set to " - << workspace_byte << " bytes, please consider reducing the batch/model size, " - << "or increasing workspace size."; - } else { - forward_algo_.Set(fwd_algo[i].algo, false); - } + forward_conv_desc_, + out_desc_, + CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT, + workspace_byte, + &fastest_fwd_algo)); + fwd->Set(fastest_fwd_algo, false); + } else { + cudnnConvolutionFwdAlgoPerf_t fwd_algo[kMaxAlgos]; + CUDNN_CALL(cudnnFindConvolutionForwardAlgorithm(s->dnn_handle_, + in_desc_, + filter_desc_, + forward_conv_desc_, + out_desc_, + kMaxAlgos, + &nalgo, + fwd_algo)); + i = 0; + while (i < nalgo + && (fwd_algo[i].status != CUDNN_STATUS_SUCCESS + || (param_.cudnn_tune.value() == conv::kLimited + && fwd_algo[i].memory > workspace_byte))) { + ++i; + min_memory_needs = + (i == 0) ? fwd_algo[i].memory : std::min(min_memory_needs, fwd_algo[i].memory); } - // Backprop-to-Filter Algorithm Find/Get, v6 and earlier - if (!param_.cudnn_tune.value()) { - cudnnConvolutionBwdFilterAlgo_t fastest_bwd_filt_algo; - CUDNN_CALL(cudnnGetConvolutionBackwardFilterAlgorithm(s->dnn_handle_, - in_desc_, - out_desc_, - back_conv_desc_w_, - filter_desc_, - CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT, - workspace_byte, - &fastest_bwd_filt_algo)); - back_algo_w_.Set(fastest_bwd_filt_algo, false); + if (i == nalgo) { + LogNoSuitableAlgoAndExit(nalgo, min_memory_needs, workspace_byte, "forward"); } else { - cudnnConvolutionBwdFilterAlgoPerf_t bwd_filter_algo[kMaxAlgos]; - CUDNN_CALL(cudnnFindConvolutionBackwardFilterAlgorithm(s->dnn_handle_, - in_desc_, - out_desc_, - back_conv_desc_w_, - filter_desc_, - kMaxAlgos, - &nalgo, - bwd_filter_algo)); - i = 0; - while (i < nalgo - && (bwd_filter_algo[i].status != CUDNN_STATUS_SUCCESS - || (param_.cudnn_tune.value() == conv::kLimited - && bwd_filter_algo[i].memory > workspace_byte))) { - ++i; - min_memory_needs = (i == 0) ? - bwd_filter_algo[i].memory : - std::min(min_memory_needs, bwd_filter_algo[i].memory); - } - if (i == nalgo) { - LOG(FATAL) << nalgo << " backward filter algorithms with minimum memory requirement " - << min_memory_needs << " bytes have been tried. 
Workspace size is set to " - << workspace_byte << " bytes, please consider reducing the batch/model size, " - << "or increasing workspace size."; - } else { - back_algo_w_.Set(bwd_filter_algo[i].algo, false); - } + fwd->Set(fwd_algo[i].algo, false); } - // Backprop-to-Data Algorithm Get(), v6 and earlier - if (!param_.cudnn_tune.value()) { - cudnnConvolutionBwdDataAlgo_t fastest_bwd_data_algo; - CUDNN_CALL(cudnnGetConvolutionBackwardDataAlgorithm(s->dnn_handle_, - filter_desc_, - out_desc_, - back_conv_desc_, - in_desc_, - CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT, - workspace_byte, - &fastest_bwd_data_algo)); - back_algo_.Set(fastest_bwd_data_algo, false); - } else { - cudnnConvolutionBwdDataAlgoPerf_t bwd_data_algo[kMaxAlgos]; - CUDNN_CALL(cudnnFindConvolutionBackwardDataAlgorithm(s->dnn_handle_, - filter_desc_, - out_desc_, - back_conv_desc_, + } + // Backprop-to-Filter Algorithm Find/Get, v6 and earlier + if (!param_.cudnn_tune.value()) { + cudnnConvolutionBwdFilterAlgo_t fastest_bwd_filt_algo; + CUDNN_CALL(cudnnGetConvolutionBackwardFilterAlgorithm(s->dnn_handle_, + in_desc_, + out_desc_, + back_conv_desc_w_, + filter_desc_, + CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT, + workspace_byte, + &fastest_bwd_filt_algo)); + flt->Set(fastest_bwd_filt_algo, false); + } else { + cudnnConvolutionBwdFilterAlgoPerf_t bwd_filter_algo[kMaxAlgos]; + CUDNN_CALL(cudnnFindConvolutionBackwardFilterAlgorithm(s->dnn_handle_, in_desc_, + out_desc_, + back_conv_desc_w_, + filter_desc_, kMaxAlgos, &nalgo, - bwd_data_algo)); - i = 0; - while (i < nalgo - && (bwd_data_algo[i].status != CUDNN_STATUS_SUCCESS - || (param_.cudnn_tune.value() == conv::kLimited - && bwd_data_algo[i].memory > workspace_byte))) { - ++i; - min_memory_needs = (i == 0) ? - bwd_data_algo[i].memory : - std::min(min_memory_needs, bwd_data_algo[i].memory); - } - if (i == nalgo) { - LOG(FATAL) << nalgo << " backward data algorithms with minimum memory requirement " - << min_memory_needs << " bytes have been tried. Workspace size is set to " - << workspace_byte << " bytes, please consider reducing the batch/model size, " - << "or increasing workspace size."; - } else { - back_algo_.Set(bwd_data_algo[i].algo, false); - } + bwd_filter_algo)); + i = 0; + while (i < nalgo + && (bwd_filter_algo[i].status != CUDNN_STATUS_SUCCESS + || (param_.cudnn_tune.value() == conv::kLimited + && bwd_filter_algo[i].memory > workspace_byte))) { + ++i; + min_memory_needs = (i == 0) ? 
+ bwd_filter_algo[i].memory : + std::min(min_memory_needs, bwd_filter_algo[i].memory); } - #endif // CUDNN_MAJOR < 7 - - // Fix for issue #11241 - int cudnn_find_issue_max_features = 64 * 1024; - if (add_to_weight_ && Features(in_shape[conv::kData]) >= cudnn_find_issue_max_features) { - this->back_algo_w_.Set(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1, true); + if (i == nalgo) { + LogNoSuitableAlgoAndExit(nalgo, min_memory_needs, workspace_byte, "backward filter"); + } else { + flt->Set(bwd_filter_algo[i].algo, false); } + } + // Backprop-to-Data Algorithm Get(), v6 and earlier + if (!param_.cudnn_tune.value()) { + cudnnConvolutionBwdDataAlgo_t fastest_bwd_data_algo; + CUDNN_CALL(cudnnGetConvolutionBackwardDataAlgorithm(s->dnn_handle_, + filter_desc_, + out_desc_, + back_conv_desc_, + in_desc_, + CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT, + workspace_byte, + &fastest_bwd_data_algo)); + bwd->Set(fastest_bwd_data_algo, false); + } else { + cudnnConvolutionBwdDataAlgoPerf_t bwd_data_algo[kMaxAlgos]; + CUDNN_CALL(cudnnFindConvolutionBackwardDataAlgorithm(s->dnn_handle_, + filter_desc_, + out_desc_, + back_conv_desc_, + in_desc_, + kMaxAlgos, + &nalgo, + bwd_data_algo)); + i = 0; + while (i < nalgo + && (bwd_data_algo[i].status != CUDNN_STATUS_SUCCESS + || (param_.cudnn_tune.value() == conv::kLimited + && bwd_data_algo[i].memory > workspace_byte))) { + ++i; + min_memory_needs = (i == 0) ? + bwd_data_algo[i].memory : + std::min(min_memory_needs, bwd_data_algo[i].memory); + } + if (i == nalgo) { + LogNoSuitableAlgoAndExit(nalgo, min_memory_needs, workspace_byte, "backward data"); + } else { + bwd->Set(bwd_data_algo[i].algo, false); + } + } +#endif // CUDNN_MAJOR < 7 - // An algo specification by the user may be cached here, but another - // convolution will match only if identically specified. - // We're caching results of *Get* as well as *Find*, but these records - // will be held distinctly because param_.cudnn_tune is part of the key. - CuDNNConvAlgoReg::Get()->Register(param_, in_shape, out_shape, dtype_, - cudnn_forward_compute_type, - cudnn_backward_compute_type, - SMArch(rctx.ctx.dev_id), this->add_to_weight_, - this->forward_algo_, - this->back_algo_, this->back_algo_w_); + // Fix for issue #11241 + int cudnn_find_issue_max_features = 64 * 1024; + if (add_to_weight_ && Features(in_shape[conv::kData]) >= cudnn_find_issue_max_features) { + flt->Set(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1, true); } + } + + void SelectAlgo(const RunContext& rctx, + const std::vector& in_shape, + const std::vector& out_shape, + cudnnDataType_t cudnn_forward_compute_type, + cudnnDataType_t cudnn_backward_compute_type) { + auto algo_setter = [&](CuDNNAlgo *fwd, + CuDNNAlgo *bwd, + CuDNNAlgo *flt) { + if (param_.cudnn_tune.value() == conv::kOff) { + // The routine will only be calling cudnnGet, so no need to grab the Storage lock. + this->CuDNNAlgoSetter(rctx, in_shape, out_shape, + cudnn_forward_compute_type, + cudnn_backward_compute_type, + fwd, bwd, flt); + } else { + // One potential problem is that cudnnFind() uses cudaMalloc() to directly allocate + // I/O and workspace areas, and these allocations may result in an out-of-memory + // error even though the StorageMangager free pool is not empty. Ideally, cudnnFind + // would use MXNet's storage allocator for its I/O and workspace areas, instead of using + // the area carved out by MXNET_GPU_MEM_POOL_RESERVE. 
+ // To get somewhat the same effect as this, we can pre-allocate the areas needed for the + // I/Os (possibly triggering a desirable StorageManager::ReleaseAll()), followed by a + // DirectFree(), which makes these areas available for cudnn's subsequent cudaMalloc(). + + // Allocate for x (or dx), w (or dw) and y (or dy). + ReserveElements({in_shape[conv::kData].Size(), + in_shape[conv::kWeight].Size(), + out_shape[conv::kOut].Size()}); + + // We're about to call cudnnFind so we need to quiet the system by grabbing + // the Storage lock. Concurrent cudaMalloc's can disrupt the accurate timing + // measurements of the algos, and can prevent the cuda driver's proper freeing + // of cudnnFind's internal temporary allocations. Grabbing the lock might also + // impede other threads from launching work on the GPU. + std::lock_guard lock(Storage::Get()->GetMutex(Context::kGPU)); + this->CuDNNAlgoSetter(rctx, in_shape, out_shape, + cudnn_forward_compute_type, + cudnn_backward_compute_type, + fwd, bwd, flt); + } + }; + + CuDNNConvAlgoReg::Get()->FindOrElseRegister(param_, in_shape, out_shape, dtype_, + cudnn_forward_compute_type, + cudnn_backward_compute_type, + SMArch(rctx.ctx.dev_id), add_to_weight_, + &forward_algo_, &back_algo_, &back_algo_w_, algo_setter); + // If we're allowing Tensor Core variants of the algos to be considered in // *Find*() or *Get*(), but a non-Tensor-Core algo variant is the fastest, // we must change the descriptor to preclude Tensor Core. Simplest is to @@ -877,6 +907,7 @@ class CuDNNConvolutionOp { << " please consider reducing batch/model size or increasing the workspace size"; } + void GetTempSize(const OpContext& ctx) { mshadow::Stream *s = ctx.get_stream(); size_t back_size = 0, back_size_w = 0; @@ -975,6 +1006,25 @@ class CuDNNConvolutionOp { return c; } + // Make a number of allocations and directly free them, ensuring room for an equivalent set of + // cudaMalloc() calls by (say) cudnnFind(). `elements` spec the alloc size in DTypes, not bytes. + void ReserveElements(const std::vector &elements) { + std::vector handles; + for (size_t alloc_element : elements) + handles.push_back(Storage::Get()->Alloc(alloc_element * sizeof(DType), Context::GPU())); + for (auto &handle : handles) + Storage::Get()->DirectFree(handle); + } + + // Log that no suitable algo was found that met the workspace constraints, then exit. + void LogNoSuitableAlgoAndExit(int num_algos_tried, size_t min_memory_needs, + size_t workspace_byte, std::string algo_kind) { + LOG(FATAL) << num_algos_tried << " " << algo_kind << " with minimum memory requirement " + << min_memory_needs << " bytes have been tried. 
Workspace size is set to " + << workspace_byte << " bytes, please consider reducing the batch/model size, " + << "or increasing workspace size."; + } + std::vector param_stride_; std::vector param_dilate_; std::vector param_pad_; diff --git a/src/operator/nn/cudnn/cudnn_deconvolution-inl.h b/src/operator/nn/cudnn/cudnn_deconvolution-inl.h index 041bea66f7bf..c0c56507bbf3 100644 --- a/src/operator/nn/cudnn/cudnn_deconvolution-inl.h +++ b/src/operator/nn/cudnn/cudnn_deconvolution-inl.h @@ -26,6 +26,7 @@ #ifndef MXNET_OPERATOR_NN_CUDNN_CUDNN_DECONVOLUTION_INL_H_ #define MXNET_OPERATOR_NN_CUDNN_CUDNN_DECONVOLUTION_INL_H_ +#include #include #include #include @@ -538,245 +539,273 @@ class CuDNNDeconvolutionOp { } } - void SelectAlgo(const RunContext& rctx, - const std::vector& in_shape, - const std::vector& out_shape, - cudnnDataType_t cudnn_forward_compute_type, - cudnnDataType_t cudnn_backward_compute_type) { - if (!CuDNNDeconvAlgoReg::Get()->Find(param_, in_shape, out_shape, dtype_, - cudnn_forward_compute_type, - cudnn_backward_compute_type, - SMArch(rctx.ctx.dev_id), add_to_weight_, - &forward_algo_, &back_algo_, &back_algo_w_)) { - mshadow::Stream *s = rctx.get_stream(); - CHECK_EQ(s->dnn_handle_ownership_, mshadow::Stream::OwnHandle); - size_t workspace_byte = static_cast(param_.workspace * sizeof(DType)); - #if CUDNN_MAJOR >= 7 - // Starting with cuDNNv7, the algo number returned by *Get*() is not the entire - // story: the notion of whether the algo ran in Tensor Core mode is not known. - // Since we want to report the Tensor Core mode in the verbose output, we switch - // to using the new *Get*_v7() call. Since the function signature of *Get*_v7() matches - // that of *Find*(), we can unify the find-vs-get logic by using function pointers. - - // Forward Algorithm Find/Get() v7 - std::vector fwd_results(MaxForwardAlgos(s->dnn_handle_)); - int actual_fwd_algos = 0; - auto fwd_algo_discoverer = - param_.cudnn_tune.value() == conv::kOff ? cudnnGetConvolutionForwardAlgorithm_v7 + void CuDNNAlgoSetter(const RunContext& rctx, + const std::vector& in_shape, + const std::vector& out_shape, + cudnnDataType_t cudnn_forward_compute_type, + cudnnDataType_t cudnn_backward_compute_type, + CuDNNAlgo *fwd, + CuDNNAlgo *bwd, + CuDNNAlgo *flt) { + // Not in algo registry, must determine via *Get*() or *Find*() + mshadow::Stream *s = rctx.get_stream(); + CHECK_EQ(s->dnn_handle_ownership_, mshadow::Stream::OwnHandle); + size_t workspace_byte = static_cast(param_.workspace * sizeof(DType)); +#if CUDNN_MAJOR >= 7 + // Starting with cuDNNv7, the algo number returned by *Get*() is not the entire + // story: the notion of whether the algo ran in Tensor Core mode is not known. + // Since we want to report the Tensor Core mode in the verbose output, we switch + // to using the new *Get*_v7() call. Since the function signature of *Get*_v7() matches + // that of *Find*(), we can unify the find-vs-get logic by using function pointers. + + // Forward Algorithm Find/Get() v7 + std::vector fwd_results(MaxForwardAlgos(s->dnn_handle_)); + int actual_fwd_algos = 0; + auto fwd_algo_discoverer = + param_.cudnn_tune.value() == deconv::kOff ? 
cudnnGetConvolutionForwardAlgorithm_v7 : cudnnFindConvolutionForwardAlgorithm; - CUDNN_CALL((*fwd_algo_discoverer)(s->dnn_handle_, - out_desc_, - filter_desc_, - back_conv_desc_, // fwd algo used to backprop-to-data - in_desc_, - fwd_results.size(), - &actual_fwd_algos, - fwd_results.data())); - fwd_results.resize(actual_fwd_algos); - AlgoFinalSelect(fwd_results, "forward", - workspace_byte, &forward_algo_); - - // Backprop-to-Filter Algorithm Find/Get() v7 - auto max_bwd_filt_algos = MaxBackwardFilterAlgos(s->dnn_handle_); - std::vector bwd_filt_results(max_bwd_filt_algos); - int actual_bwd_filter_algos = 0; - // In cudnn v7.1.4, find() returned wgrad algos that could fail for large c if we - // were summing into the output (i.e. beta != 0). Get() returned OK algos though. - auto bwd_filter_algo_discoverer = - param_.cudnn_tune.value() == conv::kOff ? cudnnGetConvolutionBackwardFilterAlgorithm_v7 - : cudnnFindConvolutionBackwardFilterAlgorithm; - CUDNN_CALL((*bwd_filter_algo_discoverer)(s->dnn_handle_, - out_desc_, - in_desc_, - back_conv_desc_, - filter_desc_, - bwd_filt_results.size(), - &actual_bwd_filter_algos, - bwd_filt_results.data())); - bwd_filt_results.resize(actual_bwd_filter_algos); - AlgoFinalSelect(bwd_filt_results, "backprop-to-filter", - workspace_byte, &back_algo_w_); - - // Backprop-to-Data Algorithm Find/Get() v7 - auto max_bwd_data_algos = MaxBackwardDataAlgos(s->dnn_handle_); - std::vector bwd_data_results(max_bwd_data_algos); - int actual_bwd_data_algos = 0; - auto bwd_data_algo_discoverer = - param_.cudnn_tune.value() == conv::kOff ? cudnnGetConvolutionBackwardDataAlgorithm_v7 + CUDNN_CALL((*fwd_algo_discoverer)(s->dnn_handle_, + out_desc_, + filter_desc_, + back_conv_desc_, // fwd algo used to backprop-to-data + in_desc_, + fwd_results.size(), + &actual_fwd_algos, + fwd_results.data())); + fwd_results.resize(actual_fwd_algos); + AlgoFinalSelect(fwd_results, "forward", + workspace_byte, fwd); + + // Backprop-to-Filter Algorithm Find/Get() v7 + auto max_bwd_filt_algos = MaxBackwardFilterAlgos(s->dnn_handle_); + std::vector bwd_filt_results(max_bwd_filt_algos); + int actual_bwd_filter_algos = 0; + // In cudnn v7.1.4, find() returned wgrad algos that could fail for large c if we + // were summing into the output (i.e. beta != 0). Get() returned OK algos though. + auto bwd_filter_algo_discoverer = + param_.cudnn_tune.value() == deconv::kOff ? cudnnGetConvolutionBackwardFilterAlgorithm_v7 + : cudnnFindConvolutionBackwardFilterAlgorithm; + CUDNN_CALL((*bwd_filter_algo_discoverer)(s->dnn_handle_, + out_desc_, + in_desc_, + back_conv_desc_, + filter_desc_, + bwd_filt_results.size(), + &actual_bwd_filter_algos, + bwd_filt_results.data())); + bwd_filt_results.resize(actual_bwd_filter_algos); + AlgoFinalSelect(bwd_filt_results, "backprop-to-filter", + workspace_byte, flt); + // Backprop-to-Data Algorithm Find/Get() v7 + auto max_bwd_data_algos = MaxBackwardDataAlgos(s->dnn_handle_); + std::vector bwd_data_results(max_bwd_data_algos); + int actual_bwd_data_algos = 0; + auto bwd_data_algo_discoverer = + param_.cudnn_tune.value() == deconv::kOff ? 
cudnnGetConvolutionBackwardDataAlgorithm_v7 : cudnnFindConvolutionBackwardDataAlgorithm; - CUDNN_CALL((*bwd_data_algo_discoverer)(s->dnn_handle_, + CUDNN_CALL((*bwd_data_algo_discoverer)(s->dnn_handle_, + filter_desc_, + in_desc_, + forward_conv_desc_, // bwd algo used in inference + out_desc_, + bwd_data_results.size(), + &actual_bwd_data_algos, + bwd_data_results.data())); + bwd_data_results.resize(actual_bwd_data_algos); + AlgoFinalSelect(bwd_data_results, "backprop-to-data", + workspace_byte, bwd); +#else + // CUDNN_MAJOR < 7 + const int kMaxAlgos = 10; + int nalgo = kMaxAlgos; + int i = 0; + size_t min_memory_needs = 0; + // Forward Algorithm Find/Get, v6 and earlier + if (CUDNN_MAJOR == 6 && param_.layout.value() == mshadow::kNHWC) { + // In cuDNNv6, for kNHWC, only CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM is + // supported. Hard-coded this since the algo find() or get() throws an FPE. + fwd->Set(CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM, false); + } else if (!param_.cudnn_tune.value()) { + cudnnConvolutionFwdAlgo_t fastest_fwd_algo; + CUDNN_CALL(cudnnGetConvolutionForwardAlgorithm(s->dnn_handle_, + out_desc_, + filter_desc_, + back_conv_desc_, // fwd algo used in dgrad + in_desc_, + CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT, + workspace_byte, + &fastest_fwd_algo)); + fwd->Set(fastest_fwd_algo, false); + } else { + cudnnConvolutionFwdAlgoPerf_t fwd_algo[kMaxAlgos]; + CUDNN_CALL(cudnnFindConvolutionForwardAlgorithm(s->dnn_handle_, + out_desc_, + filter_desc_, + back_conv_desc_, // fwd algo used in dgrad + in_desc_, + kMaxAlgos, + &nalgo, + fwd_algo)); + i = 0; + while (i < nalgo + && (fwd_algo[i].status != CUDNN_STATUS_SUCCESS + || (param_.cudnn_tune.value() == deconv::kLimited + && fwd_algo[i].memory > workspace_byte))) { + ++i; + min_memory_needs = (i == 0) ? + fwd_algo[i].memory : + std::min(min_memory_needs, fwd_algo[i].memory); + } + if (i == nalgo) { + LogNoSuitableAlgoAndExit(nalgo, min_memory_needs, workspace_byte, + "forward algos (for use in deconv op backprop-to-data)"); + } else { + fwd->Set(fwd_algo[i].algo, false); + } + } + // Backprop-to-Filter Algorithm Find/Get, v6 and earlier + if (!param_.cudnn_tune.value()) { + cudnnConvolutionBwdFilterAlgo_t fastest_bwd_filt_algo; + CUDNN_CALL(cudnnGetConvolutionBackwardFilterAlgorithm(s->dnn_handle_, + out_desc_, + in_desc_, + back_conv_desc_, + filter_desc_, + CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT, + workspace_byte, + &fastest_bwd_filt_algo)); + flt->Set(fastest_bwd_filt_algo, false); + } else { + cudnnConvolutionBwdFilterAlgoPerf_t bwd_filter_algo[kMaxAlgos]; + CUDNN_CALL(cudnnFindConvolutionBackwardFilterAlgorithm(s->dnn_handle_, + out_desc_, + in_desc_, + back_conv_desc_, + filter_desc_, + kMaxAlgos, + &nalgo, + bwd_filter_algo)); + i = 0; + while (i < nalgo + && (bwd_filter_algo[i].status != CUDNN_STATUS_SUCCESS + || (param_.cudnn_tune.value() == deconv::kLimited + && bwd_filter_algo[i].memory > workspace_byte))) { + ++i; + min_memory_needs = (i == 0) ? 
+ bwd_filter_algo[i].memory : + std::min(min_memory_needs, bwd_filter_algo[i].memory); + } + if (i == nalgo) { + LogNoSuitableAlgoAndExit(nalgo, min_memory_needs, workspace_byte, + "backward filter algos (for use in deconv op backprop-to-filter)"); + } else { + flt->Set(bwd_filter_algo[i].algo, false); + } + } + // Backprop-to-Data Algorithm Get(), v6 and earlier + if (!param_.cudnn_tune.value()) { + cudnnConvolutionBwdDataAlgo_t fastest_bwd_data_algo; + CUDNN_CALL(cudnnGetConvolutionBackwardDataAlgorithm(s->dnn_handle_, + filter_desc_, + in_desc_, + forward_conv_desc_, // bwd algo used for inference + out_desc_, + CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT, + workspace_byte, + &fastest_bwd_data_algo)); + bwd->Set(fastest_bwd_data_algo, false); + } else { + cudnnConvolutionBwdDataAlgoPerf_t bwd_data_algo[kMaxAlgos]; + CUDNN_CALL(cudnnFindConvolutionBackwardDataAlgorithm(s->dnn_handle_, filter_desc_, in_desc_, forward_conv_desc_, // bwd algo used in inference out_desc_, - bwd_data_results.size(), - &actual_bwd_data_algos, - bwd_data_results.data())); - bwd_data_results.resize(actual_bwd_data_algos); - AlgoFinalSelect(bwd_data_results, "backprop-to-data", - workspace_byte, &back_algo_); - #else - // CUDNN_MAJOR < 7 - const int kMaxAlgos = 10; - int nalgo = kMaxAlgos; - int i = 0; - size_t min_memory_needs = 0; - // Forward Algorithm Find/Get, v6 and earlier - if (CUDNN_MAJOR == 6 && param_.layout.value() == mshadow::kNHWC) { - // In cuDNNv6, for kNHWC, only CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM is - // supported. Hard-coded this since the algo find() or get() throws an FPE. - forward_algo_.Set(CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM, false); - } else if (!param_.cudnn_tune.value()) { - cudnnConvolutionFwdAlgo_t fastest_fwd_algo; - CUDNN_CALL(cudnnGetConvolutionForwardAlgorithm(s->dnn_handle_, - out_desc_, - filter_desc_, - back_conv_desc_, // fwd algo used in dgrad - in_desc_, - CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT, - workspace_byte, - &fastest_fwd_algo)); - forward_algo_.Set(fastest_fwd_algo, false); - } else { - cudnnConvolutionFwdAlgoPerf_t fwd_algo[kMaxAlgos]; - CUDNN_CALL(cudnnFindConvolutionForwardAlgorithm(s->dnn_handle_, - out_desc_, - filter_desc_, - back_conv_desc_, // fwd algo used in dgrad - in_desc_, - kMaxAlgos, - &nalgo, - fwd_algo)); - i = 0; - while (i < nalgo - && (fwd_algo[i].status != CUDNN_STATUS_SUCCESS - || (param_.cudnn_tune.value() == deconv::kLimited - && fwd_algo[i].memory > workspace_byte))) { - ++i; - min_memory_needs = (i == 0) ? - fwd_algo[i].memory : - std::min(min_memory_needs, fwd_algo[i].memory); - } - if (i == nalgo) { - LOG(FATAL) << nalgo << " forward algorithms" - << " (for use in deconvolution operator backprop-to-data)" - << " with minimum memory requirement " << min_memory_needs - << " bytes have been tried. Workspace size is set to " << workspace_byte - << " bytes, please consider reducing the batch/model size," - << " or increasing workspace size."; - } else { - forward_algo_.Set(fwd_algo[i].algo, false); - } + kMaxAlgos, + &nalgo, + bwd_data_algo)); + i = 0; + while (i < nalgo + && (bwd_data_algo[i].status != CUDNN_STATUS_SUCCESS + || (param_.cudnn_tune.value() == deconv::kLimited + && bwd_data_algo[i].memory > workspace_byte))) { + ++i; + min_memory_needs = (i == 0) ? 
+ bwd_data_algo[i].memory : + std::min(min_memory_needs, bwd_data_algo[i].memory); } - // Backprop-to-Filter Algorithm Find/Get, v6 and earlier - if (!param_.cudnn_tune.value()) { - cudnnConvolutionBwdFilterAlgo_t fastest_bwd_filt_algo; - CUDNN_CALL(cudnnGetConvolutionBackwardFilterAlgorithm(s->dnn_handle_, - out_desc_, - in_desc_, - back_conv_desc_, - filter_desc_, - CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT, - workspace_byte, - &fastest_bwd_filt_algo)); - back_algo_w_.Set(fastest_bwd_filt_algo, false); + if (i == nalgo) { + LogNoSuitableAlgoAndExit(nalgo, min_memory_needs, workspace_byte, + "backward data algos (for use in deconv op forward inference)"); } else { - cudnnConvolutionBwdFilterAlgoPerf_t bwd_filter_algo[kMaxAlgos]; - CUDNN_CALL(cudnnFindConvolutionBackwardFilterAlgorithm(s->dnn_handle_, - out_desc_, - in_desc_, - back_conv_desc_, - filter_desc_, - kMaxAlgos, - &nalgo, - bwd_filter_algo)); - i = 0; - while (i < nalgo - && (bwd_filter_algo[i].status != CUDNN_STATUS_SUCCESS - || (param_.cudnn_tune.value() == deconv::kLimited - && bwd_filter_algo[i].memory > workspace_byte))) { - ++i; - min_memory_needs = (i == 0) ? - bwd_filter_algo[i].memory : - std::min(min_memory_needs, bwd_filter_algo[i].memory); - } - if (i == nalgo) { - LOG(FATAL) << nalgo << " backward filter algorithms" - << " (for use in deconvolution operator backprop-to-filter)" - << " with minimum memory requirement " << min_memory_needs - << " bytes have been tried. Workspace size is set to " << workspace_byte - << " bytes, please consider reducing the batch/model size," - << " or increasing workspace size."; - } else { - back_algo_w_.Set(bwd_filter_algo[i].algo, false); - } + bwd->Set(bwd_data_algo[i].algo, false); } - // Backprop-to-Data Algorithm Get(), v6 and earlier - if (!param_.cudnn_tune.value()) { - cudnnConvolutionBwdDataAlgo_t fastest_bwd_data_algo; - CUDNN_CALL(cudnnGetConvolutionBackwardDataAlgorithm(s->dnn_handle_, - filter_desc_, - in_desc_, - forward_conv_desc_, // bwd algo used for inference - out_desc_, - CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT, - workspace_byte, - &fastest_bwd_data_algo)); - back_algo_.Set(fastest_bwd_data_algo, false); + } +#endif // CUDNN_MAJOR < 7 + + // Fix for issue #11241 + int cudnn_find_issue_max_features = 64 * 1024; + // With deconvolution, the algo sensitivity is to a large number of output features + if (add_to_weight_ && Features(out_shape[deconv::kOut]) >= cudnn_find_issue_max_features) { + flt->Set(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1, true); + } + } + + void SelectAlgo(const RunContext& rctx, + const std::vector& in_shape, + const std::vector& out_shape, + cudnnDataType_t cudnn_forward_compute_type, + cudnnDataType_t cudnn_backward_compute_type) { + auto algo_setter = [&](CuDNNAlgo *fwd, + CuDNNAlgo *bwd, + CuDNNAlgo *flt) { + if (param_.cudnn_tune.value() == deconv::kOff) { + // The routine will only be calling cudnnGet, so no need to grab the Storage lock. 
+ this->CuDNNAlgoSetter(rctx, in_shape, out_shape, + cudnn_forward_compute_type, + cudnn_backward_compute_type, + fwd, bwd, flt); } else { - cudnnConvolutionBwdDataAlgoPerf_t bwd_data_algo[kMaxAlgos]; - CUDNN_CALL(cudnnFindConvolutionBackwardDataAlgorithm(s->dnn_handle_, - filter_desc_, - in_desc_, - forward_conv_desc_, // bwd algo used in inference - out_desc_, - kMaxAlgos, - &nalgo, - bwd_data_algo)); - i = 0; - while (i < nalgo - && (bwd_data_algo[i].status != CUDNN_STATUS_SUCCESS - || (param_.cudnn_tune.value() == deconv::kLimited - && bwd_data_algo[i].memory > workspace_byte))) { - ++i; - min_memory_needs = (i == 0) ? - bwd_data_algo[i].memory : - std::min(min_memory_needs, bwd_data_algo[i].memory); - } - if (i == nalgo) { - LOG(FATAL) << nalgo << " backward data algorithms" - << " (for use in deconvolution operator forward inference) with" - << " minimum memory requirement " << min_memory_needs - << " bytes have been tried. Workspace size is set to " << workspace_byte - << " bytes, please consider reducing the batch/model size," - << " or increasing workspace size."; - } else { - back_algo_.Set(bwd_data_algo[i].algo, false); - } - } - #endif // CUDNN_MAJOR < 7 + // One potential problem is that cudnnFind() uses cudaMalloc() to directly allocate + // I/O and workspace areas, and these allocations may result in an out-of-memory + // error even though the StorageMangager free pool is not empty. Ideally, cudnnFind + // would use MXNet's storage allocator for its I/O and workspace areas, instead of using + // the area carved out by MXNET_GPU_MEM_POOL_RESERVE. + // To get somewhat the same effect as this, we can pre-allocate the areas needed for the + // I/Os (possibly triggering a desirable StorageManager::ReleaseAll()), followed by a + // DirectFree(), which makes these areas available for cudnn's subsequent cudaMalloc(). + + // Allocate for x (or dx), w (or dw) and y (or dy). + ReserveElements({in_shape[conv::kData].Size(), + in_shape[conv::kWeight].Size(), + out_shape[conv::kOut].Size()}); - // Fix for issue #11241 - int cudnn_find_issue_max_features = 64 * 1024; - // With deconvolution, the algo sensitivity is to a large number of output features - if (add_to_weight_ && Features(out_shape[deconv::kOut]) >= cudnn_find_issue_max_features) { - this->back_algo_w_.Set(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1, true); + // We're about to call cudnnFind so we need to quiet the system by grabbing + // the Storage lock. Concurrent cudaMalloc's can disrupt the accurate timing + // measurements of the algos, and can prevent the cuda driver's proper freeing + // of cudnnFind's internal temporary allocations. Grabbing the lock might also + // impede other threads from launching work on the GPU. + std::lock_guard lock(Storage::Get()->GetMutex(Context::kGPU)); + this->CuDNNAlgoSetter(rctx, in_shape, out_shape, + cudnn_forward_compute_type, + cudnn_backward_compute_type, + fwd, bwd, flt); } + }; + + // An algo specification by the user may be cached here, but another + // convolution will match only if identically specified. + // We're caching results of *Get* as well as *Find*, but these records + // will be held distinctly because param_.cudnn_tune is part of the key. 
+ CuDNNDeconvAlgoReg::Get()->FindOrElseRegister(param_, in_shape, out_shape, dtype_, + cudnn_forward_compute_type, + cudnn_backward_compute_type, + SMArch(rctx.ctx.dev_id), add_to_weight_, + &forward_algo_, &back_algo_, &back_algo_w_, algo_setter); - // An algo specification by the user may be cached here, but another - // convolution will match only if identically specified. - // We're caching results of *Get* as well as *Find*, but these records - // will be held distinctly because param_.cudnn_tune is part of the key. - CuDNNDeconvAlgoReg::Get()->Register(param_, in_shape, out_shape, dtype_, - cudnn_forward_compute_type, - cudnn_backward_compute_type, - SMArch(rctx.ctx.dev_id), this->add_to_weight_, - this->forward_algo_, - this->back_algo_, this->back_algo_w_); - } // If we're allowing Tensor Core variants of the algos to be considered in // *Find*() or *Get*(), but a non-Tensor-Core algo variant is the fastest, // we must change the descriptor to preclude Tensor Core. Simplest is to @@ -818,6 +847,7 @@ class CuDNNDeconvolutionOp { << " please consider reducing batch/model size or increasing the workspace size"; } + void GetTempSize(const OpContext& ctx) { mshadow::Stream *s = ctx.get_stream(); size_t back_data_algo_workspace_size = 0; @@ -921,6 +951,26 @@ class CuDNNDeconvolutionOp { return c; } + // Make a number of allocations and directly free them, ensuring room for an equivalent set of + // cudaMalloc() calls by (say) cudnnFind(). `elements` spec the alloc size in DTypes, not bytes. + void ReserveElements(const std::vector &elements) { + std::vector handles; + for (size_t alloc_element : elements) + handles.push_back(Storage::Get()->Alloc(alloc_element * sizeof(DType), Context::GPU())); + for (auto &handle : handles) + Storage::Get()->DirectFree(handle); + } + + + // Log that no suitable algo was found that met the workspace constraints, then exit. + void LogNoSuitableAlgoAndExit(int num_algos_tried, size_t min_memory_needs, + size_t workspace_byte, std::string algo_kind) { + LOG(FATAL) << num_algos_tried << " " << algo_kind << " with minimum memory requirement " + << min_memory_needs << " bytes have been tried. Workspace size is set to " + << workspace_byte << " bytes, please consider reducing the batch/model size, " + << "or increasing workspace size."; + } + std::vector param_stride_; std::vector param_dilate_; diff --git a/src/storage/pooled_storage_manager.h b/src/storage/pooled_storage_manager.h index f3a9b16cdd81..cade8d9495f4 100644 --- a/src/storage/pooled_storage_manager.h +++ b/src/storage/pooled_storage_manager.h @@ -57,6 +57,11 @@ class GPUPooledStorageManager final : public StorageManager { GPUPooledStorageManager() { reserve_ = dmlc::GetEnv("MXNET_GPU_MEM_POOL_RESERVE", 5); page_size_ = dmlc::GetEnv("MXNET_GPU_MEM_POOL_PAGE_SIZE", 4096); + large_alloc_round_size_ = dmlc::GetEnv("MXNET_GPU_MEM_LARGE_ALLOC_ROUND_SIZE", 2 * 1024 * 1024); + if (large_alloc_round_size_ <= 0) { + LOG(FATAL) << "MXNET_GPU_MEM_LARGE_ALLOC_ROUND_SIZE cannot be set to a value <= 0, found: " + << large_alloc_round_size_; + } if (page_size_ < NDEV) { LOG(FATAL) << "MXNET_GPU_MEM_POOL_PAGE_SIZE cannot be set to a value smaller than " << NDEV \ << ". 
Got " << page_size_ << "."; @@ -80,7 +85,7 @@ class GPUPooledStorageManager final : public StorageManager { private: void DirectFreeNoLock(Storage::Handle handle) { cudaError_t err = cudaFree(handle.dptr); - size_t size = std::max(handle.size, page_size_); + size_t size = RoundAllocSize(handle.size); // ignore unloading error, as memory has already been recycled if (err != cudaSuccess && err != cudaErrorCudartUnloading) { LOG(FATAL) << "CUDA: " << cudaGetErrorString(err); @@ -88,12 +93,31 @@ class GPUPooledStorageManager final : public StorageManager { used_memory_ -= size; } + // Round a value 'x' up to the next multiple of 'multiple' + size_t RoundToMultiple(size_t x, size_t multiple) { + size_t retVal = ((x + multiple - 1) / multiple) * multiple; + return retVal; + } + + size_t RoundAllocSize(size_t size) { + // Round up small allocs to the page_size_ to consolidate the pool lookups + size = std::max(size, page_size_); + // To ensure proper freeing under some driver variants, make sure + // large allocs entirely occupy their slabs, which cannot then be + // locked by smaller permanent allocations sharing the slab. + if (size > large_alloc_round_size_) + size = RoundToMultiple(size, large_alloc_round_size_); + return size; + } + private: void ReleaseAll(); // used memory size_t used_memory_ = 0; // page size size_t page_size_; + // size that large allocations should be rounded to, for proper freeing. + size_t large_alloc_round_size_; // percentage of reserved memory int reserve_; // number of devices @@ -105,7 +129,7 @@ class GPUPooledStorageManager final : public StorageManager { void GPUPooledStorageManager::Alloc(Storage::Handle* handle) { std::lock_guard lock(Storage::Get()->GetMutex(Context::kGPU)); - size_t size = std::max(handle->size, page_size_); + size_t size = RoundAllocSize(handle->size); auto&& reuse_it = memory_pool_.find(size); if (reuse_it == memory_pool_.end() || reuse_it->second.size() == 0) { size_t free, total; @@ -130,7 +154,7 @@ void GPUPooledStorageManager::Alloc(Storage::Handle* handle) { void GPUPooledStorageManager::Free(Storage::Handle handle) { std::lock_guard lock(Storage::Get()->GetMutex(Context::kGPU)); - size_t size = std::max(handle.size, page_size_); + size_t size = RoundAllocSize(handle.size); auto&& reuse_pool = memory_pool_[size]; reuse_pool.push_back(handle.dptr); } diff --git a/tests/python/gpu/test_gluon_gpu.py b/tests/python/gpu/test_gluon_gpu.py index 8394276c8ef0..8ada95b3bf24 100644 --- a/tests/python/gpu/test_gluon_gpu.py +++ b/tests/python/gpu/test_gluon_gpu.py @@ -25,12 +25,14 @@ import mxnet as mx import numpy as np import unittest +import math from nose.tools import assert_raises from mxnet.test_utils import check_consistency, set_default_context, assert_almost_equal from mxnet.base import MXNetError from mxnet import autograd from numpy.testing import assert_allclose + curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) sys.path.insert(0, os.path.join(curr_path, '../unittest')) from common import setup_module, with_seed, teardown, assert_raises_cudnn_disabled @@ -57,7 +59,7 @@ def check_rnn_layer(layer): for g, c in zip(gs, cs): assert_almost_equal(g.asnumpy(), c.asnumpy(), rtol=1e-2, atol=1e-6) - +@with_seed() def check_rnn_layer_w_rand_inputs(layer): layer.collect_params().initialize(ctx=[mx.cpu(0), mx.gpu(0)]) x = mx.nd.uniform(shape=(10, 16, 30)) @@ -186,7 +188,7 @@ def _syncParameters(bn1, bn2, ctx): input2grad = mx.nd.concat(*[output.grad.as_in_context(input.context) for output in inputs2], dim=0) 
assert_almost_equal(input1.grad.asnumpy(), input2grad.asnumpy(), atol=1e-3, rtol=1e-3) - +@with_seed() def test_sync_batchnorm(): def get_num_devices(): for i in range(100): @@ -203,6 +205,7 @@ def get_num_devices(): _check_batchnorm_result(mx.nd.random.uniform(shape=(4, 1, 4, 4)), num_devices=ndev, cuda=True) + @with_seed() def test_symbol_block_fp16(): # Test case to verify if initializing the SymbolBlock from a model with params @@ -233,6 +236,45 @@ def test_symbol_block_fp16(): break assert np.dtype(net_fp16.params[param_name].dtype) == np.dtype(np.float16) + +@with_seed() +def test_large_models(): + ctx = default_context() + # Create model + net = gluon.nn.HybridSequential() + + largest_num_features = 256 + with net.name_scope(): + net.add(nn.Conv2D(largest_num_features, 3)) + + net.hybridize() + net.initialize(mx.init.Normal(sigma=0.01), ctx=ctx) + + # Compute the height (=width) of the square tensor of the given size in bytes + def tensor_size(big_tensor_bytes): + bytes_per_float = 4 + sz = int(math.sqrt(big_tensor_bytes / largest_num_features / bytes_per_float)) + return (sz // 100) * 100 + + # The idea is to create models with large tensors of (say) 20% of the total memory. + # This in the past has given cudnnFind() trouble when it needed to allocate similar I/O's + # from the area carved out by the MXNET_GPU_MEM_POOL_RESERVE setting (by default 5%). + (free_mem_bytes, total_mem_bytes) = mx.context.gpu_memory_info(ctx.device_id) + start_size = tensor_size(0.20 * total_mem_bytes) + num_trials = 10 + sys.stderr.write(' testing global memory of size {} ... '.format(total_mem_bytes)) + sys.stderr.flush() + for i in range(num_trials): + sz = start_size - 10 * i + (height, width) = (sz,sz) + sys.stderr.write(" {}x{} ".format(height,width)) + sys.stderr.flush() + data_in = nd.random_uniform(low=0, high=255, shape=(1, 3, height, width), + ctx=ctx, dtype="float32") + # Evaluate model + net(data_in).asnumpy() + + if __name__ == '__main__': import nose nose.runmodule()
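
A quick way to exercise the new API and the large-allocation rounding from Python (a minimal sketch, assuming a GPU build of MXNet that includes this patch; `round_to_multiple` is a hypothetical helper that simply mirrors the arithmetic of `RoundToMultiple()` in pooled_storage_manager.h and is not part of the patch):

```python
import mxnet as mx

# Free and total GPU global memory in bytes, via the new MXGetGPUMemoryInformation64 wrapper.
free_bytes, total_bytes = mx.context.gpu_memory_info(device_id=0)
print('gpu(0): {} bytes free of {} bytes total'.format(free_bytes, total_bytes))

# Requests above MXNET_GPU_MEM_LARGE_ALLOC_ROUND_SIZE (default 2097152) are rounded up
# to a multiple of that value before entering the pool, to limit driver-level fragmentation.
def round_to_multiple(x, multiple):
    return ((x + multiple - 1) // multiple) * multiple

print(round_to_multiple(5 * 1024 * 1024 + 123, 2 * 1024 * 1024))  # 6291456, i.e. 6 MiB
```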