From a85b3f00aa62e3b7f75110b5cdac418f87309c3d Mon Sep 17 00:00:00 2001 From: Anirudh Subramanian Date: Tue, 5 Feb 2019 23:12:50 +0000 Subject: [PATCH 01/18] Relax constexpr restriction --- Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Makefile b/Makefile index cc7aebc9fb89..ccc3c7828ba9 100644 --- a/Makefile +++ b/Makefile @@ -89,7 +89,7 @@ CFLAGS = -DMSHADOW_FORCE_STREAM $(WARNFLAGS) ifeq ($(DEV), 1) CFLAGS += -g -Werror - NVCCFLAGS += -Werror cross-execution-space-call + NVCCFLAGS += -Werror cross-execution-space-call --expt-relaxed-constexpr endif # CFLAGS for debug From 5bd642846d5c5885c9c1096caeff86629ad85110 Mon Sep 17 00:00:00 2001 From: Anirudh Subramanian Date: Tue, 12 Feb 2019 23:05:39 +0000 Subject: [PATCH 02/18] Change imagenet_gen_qsym_mkldnn --- example/quantization/imagenet_gen_qsym_mkldnn.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/example/quantization/imagenet_gen_qsym_mkldnn.py b/example/quantization/imagenet_gen_qsym_mkldnn.py index 938890bb75df..18e6e5d77eeb 100644 --- a/example/quantization/imagenet_gen_qsym_mkldnn.py +++ b/example/quantization/imagenet_gen_qsym_mkldnn.py @@ -179,7 +179,9 @@ def save_params(fname, arg_params, aux_params, logger=None): prefix, epoch = download_model(model_name=args.model, logger=logger) sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch) + sym.save("model_pre_quantize.json") sym = sym.get_backend_symbol('MKLDNN') + sym.save("model_post_quantize.json") # get batch size batch_size = args.batch_size @@ -314,6 +316,7 @@ def save_params(fname, arg_params, aux_params, logger=None): % calib_mode) sym_name = '%s-symbol.json' % (prefix + suffix) qsym = qsym.get_backend_symbol('MKLDNN_POST_QUANTIZE') + qsym.save("model_post_post_quantize.json") save_symbol(sym_name, qsym, logger) param_name = '%s-%04d.params' % (prefix + '-quantized', epoch) save_params(param_name, qarg_params, aux_params, logger) From 26079e693b795579e9393e3b514aff7cfcde46e2 Mon Sep 17 00:00:00 2001 From: Anirudh Subramanian Date: Tue, 12 Mar 2019 00:43:14 +0000 Subject: [PATCH 03/18] Add exception handling support for waitall --- src/engine/threaded_engine.cc | 27 +++++ src/engine/threaded_engine.h | 29 +++++- tests/python/unittest/test_exc_handling.py | 110 +++++++++++++++------ 3 files changed, 132 insertions(+), 34 deletions(-) diff --git a/src/engine/threaded_engine.cc b/src/engine/threaded_engine.cc index b5897a1ca9cd..3b3ecfb349c6 100644 --- a/src/engine/threaded_engine.cc +++ b/src/engine/threaded_engine.cc @@ -415,6 +415,25 @@ void ThreadedEngine::WaitForAll() { finished_cv_.wait(lock, [this]() { return pending_.load() == 0 || kill_.load(); }); + std::exception_ptr tmp; + if (!global_exception_refs_.empty()) { + // iterate through all exception refs + for (auto itr = global_exception_refs_.begin(); + itr != global_exception_refs_.end(); ++itr) { + const std::shared_ptr& ptr = *itr; + // the first exception will be saved to be rethrown later + if (*ptr != nullptr && !tmp) { + tmp = *ptr; + } + // clear exceptions, WaitToRead following WaitForAll shouldn't throw + *ptr = nullptr; + } + // A waitall following a waitall shouldn't throw any exceptions + global_exception_refs_.clear(); + if (tmp != nullptr) { + std::rethrow_exception(tmp); + } + } } inline void ThreadedEngine::OnComplete(ThreadedOpr* threaded_opr) { @@ -428,6 +447,14 @@ inline void ThreadedEngine::OnComplete(ThreadedOpr* threaded_opr) { for (auto&& i : threaded_opr->mutable_vars) { if (threaded_opr->opr_exception && *threaded_opr->opr_exception) { i->var_exception = threaded_opr->opr_exception; + // add current operator exceptions to global exceptions if not already + // added + auto it = std::find(global_exception_refs_.begin(), + global_exception_refs_.end(), + threaded_opr->opr_exception); + if (it == global_exception_refs_.end()) { + global_exception_refs_.push_back(threaded_opr->opr_exception); + } } const bool debug_info = (engine_info_ && debug_wait_var_ == i); if (debug_info) { diff --git a/src/engine/threaded_engine.h b/src/engine/threaded_engine.h index ab06ca1b9b47..de5e2cda7dac 100644 --- a/src/engine/threaded_engine.h +++ b/src/engine/threaded_engine.h @@ -176,7 +176,11 @@ class ThreadedVar final static std::atomic counter; ~ThreadedVar() { LOG(INFO) << __func__ << " " << --counter; } #endif // ENGINE_DEBUG - /*! \brief exception_ptr associated with the ThreadedVar */ + /*! + * \brief exception_ptr associated with the ThreadedOpr + * cannot modify state of exception object since dereferencing + * exception_ptr is undefined behavior. Using shared_ptr to hold + * exception_ptr and overcome this limitation*/ std::shared_ptr var_exception; private: @@ -253,7 +257,11 @@ struct ThreadedOpr final : public Opr, } // define possible debug information DEFINE_ENGINE_DEBUG_INFO(ThreadedOpr); - /*! \brief exception_ptr associated with the ThreadedOpr */ + /*! + * \brief exception_ptr associated with the ThreadedOpr + * cannot modify state of exception object since dereferencing + * exception_ptr is undefined behavior. Using shared_ptr to hold + * exception_ptr and overcome this limitation*/ std::shared_ptr opr_exception; }; // struct ThreadedOpr @@ -429,6 +437,9 @@ class ThreadedEngine : public Engine { }; /*! thread local store for bulk */ typedef dmlc::ThreadLocalStore BulkStatusStore; + /*! shared_ptr to exception_ptr, used for exception handling */ + typedef std::shared_ptr ExceptionRef; + /*! * \brief check if thee is duplication in const_vars and mutable_vars. * \param const_vars the variables to read from. @@ -457,6 +468,12 @@ class ThreadedEngine : public Engine { for (auto&& i : threaded_opr->const_vars) { if (i->var_exception && *i->var_exception) { threaded_opr->opr_exception = i->var_exception; + auto it = std::find(global_exception_refs_.begin(), + global_exception_refs_.end(), + threaded_opr->opr_exception); + if (it == global_exception_refs_.end()) { + global_exception_refs_.push_back(threaded_opr->opr_exception); + } break; } } @@ -464,6 +481,12 @@ class ThreadedEngine : public Engine { for (auto&& i : threaded_opr->mutable_vars) { if (i->var_exception && *i->var_exception) { threaded_opr->opr_exception = i->var_exception; + auto it = std::find(global_exception_refs_.begin(), + global_exception_refs_.end(), + threaded_opr->opr_exception); + if (it == global_exception_refs_.end()) { + global_exception_refs_.push_back(threaded_opr->opr_exception); + } break; } } @@ -539,6 +562,8 @@ class ThreadedEngine : public Engine { */ std::mutex finished_m_; std::condition_variable finished_cv_; + /*! \brief global exception refs, which are rethrown when WaitForAll is called*/ + std::vector global_exception_refs_; /*! * \brief Holding a shared_ptr to the object pool to prevent it from being destructed too early diff --git a/tests/python/unittest/test_exc_handling.py b/tests/python/unittest/test_exc_handling.py index e9e161d7f3b6..b7b8aa6c1d48 100644 --- a/tests/python/unittest/test_exc_handling.py +++ b/tests/python/unittest/test_exc_handling.py @@ -34,11 +34,11 @@ def imperative(exec_numpy=True): c.asnumpy() imperative(exec_numpy=False) - assert_raises(MXNetError, imperative, True) + assert_raises(MXNetError, imperative, exec_numpy=True) @with_seed() def test_exc_symbolic(): - def symbolic(exec_backward=True): + def symbolic(exec_backward=True, waitall=True): x = mx.sym.Variable('x') y = mx.sym.Variable('y') z = mx.sym.Variable('z') @@ -58,16 +58,25 @@ def symbolic(exec_backward=True): outputs = exec1.forward() if exec_backward: exec1.backward() - exec1.grad_arrays[0].asnumpy() + if waitall: + mx.nd.waitall() + else: + exec1.grad_arrays[0].asnumpy() else: - outputs[0].asnumpy() + if waitall: + mx.nd.waitall() + else: + outputs[0].asnumpy() - assert_raises(MXNetError, symbolic, False) - assert_raises(MXNetError, symbolic, True) + assert_raises(MXNetError, symbolic, exec_backward=False) + assert_raises(MXNetError, symbolic, exec_backward=True) + + assert_raises(MXNetError, symbolic, exec_backward=False, waitall=True) + assert_raises(MXNetError, symbolic, exec_backward=True, waitall=True) @with_seed() def test_exc_gluon(): - def gluon(exec_wait=True): + def gluon(exec_wait=True, waitall=False): model = nn.Sequential() model.add(nn.Dense(128, activation='tanh', in_units=10, flatten=False)) model.add(nn.Dropout(1)) @@ -77,46 +86,83 @@ def gluon(exec_wait=True): y = model(x) model.collect_params().initialize(ctx=[default_context()]) z = model(mx.nd.random.normal(10, -10, (32, 2, 10), ctx=default_context())) - if exec_wait: + if waitall: + mx.nd.waitall() + elif exec_wait: z.wait_to_read() gluon(exec_wait=False) - assert_raises(MXNetError, gluon, True) + assert_raises(MXNetError, gluon, exec_wait=True) + + assert_raises(MXNetError, gluon, waitall=True) @with_seed() def test_exc_multiple_waits(): - caught = False - try: - a = mx.nd.random.normal(0, -1, (2, 2)).copyto(default_context()) - a.wait_to_read() - except MXNetError: - caught = True - assert caught, "No exception thrown" - try: - b = mx.nd.random.normal(0, -1, (2, 2)).copyto(default_context()) - b.wait_to_read() - except MXNetError: - caught = True - assert caught, "No exception thrown" + def multiple_waits(waitall=False): + caught = False + try: + a = mx.nd.random.normal(0, -1, (2, 2)).copyto(default_context()) + if waitall: + mx.nd.waitall() + else: + a.wait_to_read() + except MXNetError: + caught = True + assert caught, "No exception thrown" + try: + b = mx.nd.random.normal(0, -1, (2, 2)).copyto(default_context()) + if waitall: + mx.nd.waitall() + else: + b.wait_to_read() + except MXNetError: + caught = True + assert caught, "No exception thrown" + + multiple_waits(waitall=False) + multiple_waits(waitall=True) @with_seed() def test_exc_post_fail(): + def post_fail(waitall=False): + caught = False + try: + a, b = mx.nd.random_normal(0, -1, (2, 2)).copyto(default_context()) + if waitall: + mx.nd.waitall() + else: + a.asnumpy() + except MXNetError: + caught = True + assert caught, "No exception thrown" + b.asnumpy() + post_fail(waitall=False) + post_fail(waitall=True) + +@with_seed() +def test_exc_mutable_var_fail(): + def mutable_var_check(waitall=False): + a, b = mx.nd.random_normal(0, -1, (2, 2)).copyto(default_context()) + a = mx.nd.dot(a, a) + if waitall: + mx.nd.waitall() + else: + a.asnumpy() + assert_raises(MXNetError, mutable_var_check, waitall=False) + assert_raises(MXNetError, mutable_var_check, waitall=True) + +@with_seed() +def test_multiple_waitalls(): caught = False try: - a, b = mx.nd.random_normal(0, -1, (2, 2)).copyto(default_context()) - a.asnumpy() + a = mx.nd.random.normal(0, -1, (2, 2)).copyto(default_context()) + mx.nd.waitall() except MXNetError: caught = True assert caught, "No exception thrown" - b.asnumpy() + mx.nd.waitall() + -@with_seed() -def test_exc_mutable_var_fail(): - def mutable_var_check(): - a, b = mx.nd.random_normal(0, -1, (2, 2)).copyto(default_context()) - a = mx.nd.dot(a, a) - a.asnumpy() - assert_raises(MXNetError, mutable_var_check) if __name__ == '__main__': import nose From f0a76e3ab1a76707fc890a5f1b32fd798f8fb33f Mon Sep 17 00:00:00 2001 From: Anirudh Subramanian Date: Tue, 12 Mar 2019 02:04:15 +0000 Subject: [PATCH 04/18] Fix exception handling documentation --- docs/architecture/exception_handling.md | 3 --- python/mxnet/ndarray/ndarray.py | 7 ------- 2 files changed, 10 deletions(-) diff --git a/docs/architecture/exception_handling.md b/docs/architecture/exception_handling.md index 6a9ab9ae0c4c..87481bcdb9bd 100644 --- a/docs/architecture/exception_handling.md +++ b/docs/architecture/exception_handling.md @@ -123,6 +123,3 @@ except mx.base.MXNetError as ex: d.asnumpy() ``` -### Limitation - -Rethrowing exceptions as part of `mx.nd.waitall` is not supported. So if your code executes a few operators and then calls `waitall` instead of `wait_to_read`/`asnumpy`, the exception will disappear. Please avoid waitalls in your code unless you are confident about your code not throwing exception in any scenario. diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py index fb329f1865a9..828cfa6b1013 100644 --- a/python/mxnet/ndarray/ndarray.py +++ b/python/mxnet/ndarray/ndarray.py @@ -157,13 +157,6 @@ def waitall(): """Wait for all async operations to finish in MXNet. This function is used for benchmarking only. - - .. warning:: - - If your code has exceptions, `waitall` can cause silent failures. - For this reason you should avoid `waitall` in your code. - Use it only if you are confident that your code is error free. - Then make sure you call `wait_to_read` on all outputs after `waitall`. """ check_call(_LIB.MXNDArrayWaitAll()) From cd98fa9b771510fa541bda869ed0b55015aa5b72 Mon Sep 17 00:00:00 2001 From: Anirudh Subramanian Date: Tue, 12 Mar 2019 02:08:27 +0000 Subject: [PATCH 05/18] Revert constexpr change --- Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Makefile b/Makefile index c38abe1f93a2..29443eba2275 100644 --- a/Makefile +++ b/Makefile @@ -89,7 +89,7 @@ CFLAGS = -DMSHADOW_FORCE_STREAM $(WARNFLAGS) ifeq ($(DEV), 1) CFLAGS += -g -Werror - NVCCFLAGS += -Werror cross-execution-space-call --expt-relaxed-constexpr + NVCCFLAGS += -Werror cross-execution-space-call endif # CFLAGS for debug From 4f694f66680914292d91c4cc1107b4bc0c07c742 Mon Sep 17 00:00:00 2001 From: Anirudh Subramanian Date: Tue, 12 Mar 2019 05:30:40 +0000 Subject: [PATCH 06/18] Add comments --- src/engine/threaded_engine.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/engine/threaded_engine.h b/src/engine/threaded_engine.h index 55443f8a80c7..cd9e4ca0c482 100644 --- a/src/engine/threaded_engine.h +++ b/src/engine/threaded_engine.h @@ -181,7 +181,7 @@ class ThreadedVar final * \brief exception_ptr associated with the ThreadedOpr * cannot modify state of exception object since dereferencing * exception_ptr is undefined behavior. Using shared_ptr to hold - * exception_ptr and overcome this limitation*/ + * exception_ptr and overcome this limitation */ std::shared_ptr var_exception; private: @@ -262,7 +262,7 @@ struct ThreadedOpr final : public Opr, * \brief exception_ptr associated with the ThreadedOpr * cannot modify state of exception object since dereferencing * exception_ptr is undefined behavior. Using shared_ptr to hold - * exception_ptr and overcome this limitation*/ + * exception_ptr and overcome this limitation */ std::shared_ptr opr_exception; }; // struct ThreadedOpr @@ -565,7 +565,7 @@ class ThreadedEngine : public Engine { */ std::mutex finished_m_; std::condition_variable finished_cv_; - /*! \brief global exception refs, which are rethrown when WaitForAll is called*/ + /*! \brief global exception refs, which are rethrown when WaitForAll is called */ std::vector global_exception_refs_; /*! From d79560bf77cb056e12caea2a91a0109dfe1fcfd7 Mon Sep 17 00:00:00 2001 From: Anirudh Subramanian Date: Tue, 12 Mar 2019 05:50:56 +0000 Subject: [PATCH 07/18] Fix test --- tests/python/unittest/test_operator.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 7169395205e0..549bae6d1507 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -6961,7 +6961,12 @@ def get_output_names_callback(name, arr): op_exe = op_sym.simple_bind(ctx=mx.current_context(), grad_req='null') op_exe.set_monitor_callback(get_output_names_callback, monitor_all=False) - op_exe.forward() + try: + op_exe.forward() + mx.nd.waitall() + except mx.base.MXNetError: + # skip errors since test is to check output names + pass for output_name, expected_name in zip(output_names, expected_names): assert output_name == expected_name From 3a581e848e2a100b2aa0cc52404d1ab809bb64b4 Mon Sep 17 00:00:00 2001 From: Anirudh Subramanian Date: Tue, 12 Mar 2019 16:04:04 +0000 Subject: [PATCH 08/18] Skip exception for op check names --- tests/python/unittest/test_operator.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 549bae6d1507..adae67cf17f6 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -7012,7 +7012,12 @@ def get_output_names_callback(name, arr): op_exe = op_sym.simple_bind(ctx=mx.current_context(), grad_req='null') op_exe.set_monitor_callback(get_output_names_callback, monitor_all=True) - op_exe.forward() + try: + op_exe.forward() + mx.nd.waitall() + except mx.base.MXNetError: + # skip errors since test is to check all names + pass for output_name, expected_name in zip(output_names, expected_names): assert output_name == expected_name From 0b5444b6162b70c5be6276dcb680904a5f1b8345 Mon Sep 17 00:00:00 2001 From: Anirudh Subramanian Date: Tue, 12 Mar 2019 18:14:45 +0000 Subject: [PATCH 09/18] Print exceptions thrown for CPP Package NDArray module --- cpp-package/include/mxnet-cpp/ndarray.hpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/cpp-package/include/mxnet-cpp/ndarray.hpp b/cpp-package/include/mxnet-cpp/ndarray.hpp index 966cf75c9122..b667542bffb5 100644 --- a/cpp-package/include/mxnet-cpp/ndarray.hpp +++ b/cpp-package/include/mxnet-cpp/ndarray.hpp @@ -233,12 +233,12 @@ inline NDArray NDArray::Reshape(const Shape &new_shape) const { return NDArray(handle); } inline void NDArray::WaitToRead() const { - CHECK_EQ(MXNDArrayWaitToRead(blob_ptr_->handle_), 0); + CHECK_EQ(MXNDArrayWaitToRead(blob_ptr_->handle_), 0) << MXGetLastError(); } inline void NDArray::WaitToWrite() { - CHECK_EQ(MXNDArrayWaitToWrite(blob_ptr_->handle_), 0); + CHECK_EQ(MXNDArrayWaitToWrite(blob_ptr_->handle_), 0) << MXGetLastError(); } -inline void NDArray::WaitAll() { CHECK_EQ(MXNDArrayWaitAll(), 0); } +inline void NDArray::WaitAll() { CHECK_EQ(MXNDArrayWaitAll(), 0) << MXGetLastError(); } inline void NDArray::SampleGaussian(mx_float mu, mx_float sigma, NDArray *out) { Operator("_random_normal")(mu, sigma).Invoke(*out); } From 38b8dcad7a3eaa8717674a351badfeec9cecca5d Mon Sep 17 00:00:00 2001 From: Anirudh Subramanian Date: Thu, 14 Mar 2019 00:04:33 +0000 Subject: [PATCH 10/18] Reducing batch_size to make cpp-package example pass --- cpp-package/example/resnet.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp-package/example/resnet.cpp b/cpp-package/example/resnet.cpp index 29071bd72ee1..371565fb6e91 100644 --- a/cpp-package/example/resnet.cpp +++ b/cpp-package/example/resnet.cpp @@ -184,7 +184,7 @@ int main(int argc, char const *argv[]) { #if !MXNET_USE_CPU if (num_gpu > 0) { ctx = Context::gpu(); - batch_size = 50; + batch_size = 32; } #endif From 1c0d936896c6e153b1942d624030553b9cbed519 Mon Sep 17 00:00:00 2001 From: Anirudh Subramanian Date: Thu, 14 Mar 2019 00:08:23 +0000 Subject: [PATCH 11/18] Fix bug: #14426 --- src/resource.cc | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/resource.cc b/src/resource.cc index 80a5c0e444e1..fe9da551ef79 100644 --- a/src/resource.cc +++ b/src/resource.cc @@ -190,12 +190,14 @@ class ResourceManagerImpl : public ResourceManager { cpu_rand_->Seed(seed); cpu_parallel_rand_->Seed(seed); #if MXNET_USE_CUDA - gpu_rand_.Get(ctx.dev_id, [ctx, seed, this]() { - return new ResourceRandom(ctx, seed); - })->Seed(seed); - gpu_parallel_rand_.Get(ctx.dev_id, [ctx, seed, this]() { - return new ResourceParallelRandom(ctx, gpu_native_rand_copy_, seed); - })->Seed(seed); + if (ctx != Context::CPU()) { + gpu_rand_.Get(ctx.dev_id, [ctx, seed, this]() { + return new ResourceRandom(ctx, seed); + })->Seed(seed); + gpu_parallel_rand_.Get(ctx.dev_id, [ctx, seed, this]() { + return new ResourceParallelRandom(ctx, gpu_native_rand_copy_, seed); + })->Seed(seed); + } #endif } From 034e9c7ed27ea08149d0850a7ad21bfb83f69ac9 Mon Sep 17 00:00:00 2001 From: Anirudh Subramanian Date: Thu, 14 Mar 2019 18:09:02 +0000 Subject: [PATCH 12/18] use ExceptionRef in threaded_engine code --- src/engine/threaded_engine.cc | 2 +- src/engine/threaded_engine.h | 9 +++++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/src/engine/threaded_engine.cc b/src/engine/threaded_engine.cc index 3b3ecfb349c6..7c3667dffbc6 100644 --- a/src/engine/threaded_engine.cc +++ b/src/engine/threaded_engine.cc @@ -420,7 +420,7 @@ void ThreadedEngine::WaitForAll() { // iterate through all exception refs for (auto itr = global_exception_refs_.begin(); itr != global_exception_refs_.end(); ++itr) { - const std::shared_ptr& ptr = *itr; + const ExceptionRef& ptr = *itr; // the first exception will be saved to be rethrown later if (*ptr != nullptr && !tmp) { tmp = *ptr; diff --git a/src/engine/threaded_engine.h b/src/engine/threaded_engine.h index cd9e4ca0c482..229ee383b1b1 100644 --- a/src/engine/threaded_engine.h +++ b/src/engine/threaded_engine.h @@ -60,6 +60,9 @@ namespace engine { // Forward declarations struct ThreadedOpr; +/*! shared_ptr to exception_ptr, used for exception handling */ +typedef std::shared_ptr ExceptionRef; + /*! * \brief Operation block in the scheduler. * Each OprBlock corresponds to an operation pushed to the engine. @@ -182,7 +185,7 @@ class ThreadedVar final * cannot modify state of exception object since dereferencing * exception_ptr is undefined behavior. Using shared_ptr to hold * exception_ptr and overcome this limitation */ - std::shared_ptr var_exception; + ExceptionRef var_exception; private: // TODO(hotpxl) change this to spinlock for faster runtime @@ -263,7 +266,7 @@ struct ThreadedOpr final : public Opr, * cannot modify state of exception object since dereferencing * exception_ptr is undefined behavior. Using shared_ptr to hold * exception_ptr and overcome this limitation */ - std::shared_ptr opr_exception; + ExceptionRef opr_exception; }; // struct ThreadedOpr /*! @@ -440,8 +443,6 @@ class ThreadedEngine : public Engine { }; /*! thread local store for bulk */ typedef dmlc::ThreadLocalStore BulkStatusStore; - /*! shared_ptr to exception_ptr, used for exception handling */ - typedef std::shared_ptr ExceptionRef; /*! * \brief check if thee is duplication in const_vars and mutable_vars. From d34d95e229065a37010c94912b6c03aabb8340c4 Mon Sep 17 00:00:00 2001 From: Anirudh Subramanian Date: Thu, 14 Mar 2019 18:16:47 +0000 Subject: [PATCH 13/18] add note for performance impact of waitall --- python/mxnet/ndarray/ndarray.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py index 7238c55a73f9..0ec8fdd7e91c 100644 --- a/python/mxnet/ndarray/ndarray.py +++ b/python/mxnet/ndarray/ndarray.py @@ -157,6 +157,10 @@ def waitall(): """Wait for all async operations to finish in MXNet. This function is used for benchmarking only. + + .. note:: + + If your mxnet code throws an exception, then waitall can cause performance impact. """ check_call(_LIB.MXNDArrayWaitAll()) From 48a6638d8f2b4cedc84e3bc2c8b7ad723e027966 Mon Sep 17 00:00:00 2001 From: Anirudh Subramanian Date: Wed, 20 Mar 2019 21:36:32 +0000 Subject: [PATCH 14/18] Add check for GPU contxt --- src/engine/threaded_engine.cc | 7 +------ src/engine/threaded_engine.h | 26 ++++++++++++++------------ src/resource.cc | 2 +- 3 files changed, 16 insertions(+), 19 deletions(-) diff --git a/src/engine/threaded_engine.cc b/src/engine/threaded_engine.cc index 7c3667dffbc6..6e5c7ce72a82 100644 --- a/src/engine/threaded_engine.cc +++ b/src/engine/threaded_engine.cc @@ -449,12 +449,7 @@ inline void ThreadedEngine::OnComplete(ThreadedOpr* threaded_opr) { i->var_exception = threaded_opr->opr_exception; // add current operator exceptions to global exceptions if not already // added - auto it = std::find(global_exception_refs_.begin(), - global_exception_refs_.end(), - threaded_opr->opr_exception); - if (it == global_exception_refs_.end()) { - global_exception_refs_.push_back(threaded_opr->opr_exception); - } + AddToGlobalExceptions(threaded_opr->opr_exception); } const bool debug_info = (engine_info_ && debug_wait_var_ == i); if (debug_info) { diff --git a/src/engine/threaded_engine.h b/src/engine/threaded_engine.h index 229ee383b1b1..3d2119d63291 100644 --- a/src/engine/threaded_engine.h +++ b/src/engine/threaded_engine.h @@ -472,12 +472,7 @@ class ThreadedEngine : public Engine { for (auto&& i : threaded_opr->const_vars) { if (i->var_exception && *i->var_exception) { threaded_opr->opr_exception = i->var_exception; - auto it = std::find(global_exception_refs_.begin(), - global_exception_refs_.end(), - threaded_opr->opr_exception); - if (it == global_exception_refs_.end()) { - global_exception_refs_.push_back(threaded_opr->opr_exception); - } + AddToGlobalExceptions(threaded_opr->opr_exception); break; } } @@ -485,12 +480,7 @@ class ThreadedEngine : public Engine { for (auto&& i : threaded_opr->mutable_vars) { if (i->var_exception && *i->var_exception) { threaded_opr->opr_exception = i->var_exception; - auto it = std::find(global_exception_refs_.begin(), - global_exception_refs_.end(), - threaded_opr->opr_exception); - if (it == global_exception_refs_.end()) { - global_exception_refs_.push_back(threaded_opr->opr_exception); - } + AddToGlobalExceptions(threaded_opr->opr_exception); break; } } @@ -499,6 +489,18 @@ class ThreadedEngine : public Engine { static void OnCompleteStatic(Engine *engine, void *threaded_opr, const dmlc::Error* error); + /*! + * \brief find exception in global_exception_refs and add it if missing + * \param opr_exception the exception to be added to global_exception_refs + */ + inline void AddToGlobalExceptions(const ExceptionRef& opr_exception) { + auto it = std::find(global_exception_refs_.begin(), + global_exception_refs_.end(), opr_exception); + if (it == global_exception_refs_.end()) { + global_exception_refs_.push_back(opr_exception); + } + return; + } /*! \brief append an operator to bulk */ inline void BulkAppend(SyncFn exec_fn, Context exec_ctx, std::vector const& const_vars, diff --git a/src/resource.cc b/src/resource.cc index fe9da551ef79..403743d44325 100644 --- a/src/resource.cc +++ b/src/resource.cc @@ -190,7 +190,7 @@ class ResourceManagerImpl : public ResourceManager { cpu_rand_->Seed(seed); cpu_parallel_rand_->Seed(seed); #if MXNET_USE_CUDA - if (ctx != Context::CPU()) { + if (ctx.dev_type == Context::kGPU) { gpu_rand_.Get(ctx.dev_id, [ctx, seed, this]() { return new ResourceRandom(ctx, seed); })->Seed(seed); From f631d57a3a9e5f4e329d17c41930bbeda9824ffa Mon Sep 17 00:00:00 2001 From: Anirudh Subramanian Date: Thu, 21 Mar 2019 20:38:37 +0000 Subject: [PATCH 15/18] Use range for with const reference --- src/engine/threaded_engine.cc | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/engine/threaded_engine.cc b/src/engine/threaded_engine.cc index 6e5c7ce72a82..30d1e75f9a77 100644 --- a/src/engine/threaded_engine.cc +++ b/src/engine/threaded_engine.cc @@ -418,15 +418,13 @@ void ThreadedEngine::WaitForAll() { std::exception_ptr tmp; if (!global_exception_refs_.empty()) { // iterate through all exception refs - for (auto itr = global_exception_refs_.begin(); - itr != global_exception_refs_.end(); ++itr) { - const ExceptionRef& ptr = *itr; + for (const auto& global_exception_ref : global_exception_refs_) { // the first exception will be saved to be rethrown later - if (*ptr != nullptr && !tmp) { - tmp = *ptr; + if (*global_exception_ref != nullptr && tmp == false) { + tmp = *global_exception_ref; } // clear exceptions, WaitToRead following WaitForAll shouldn't throw - *ptr = nullptr; + *global_exception_ref = nullptr; } // A waitall following a waitall shouldn't throw any exceptions global_exception_refs_.clear(); From 54e301f8854211ad889ecebde7b95cf13ae090ff Mon Sep 17 00:00:00 2001 From: Anirudh Subramanian Date: Thu, 21 Mar 2019 20:54:59 +0000 Subject: [PATCH 16/18] Improve comments and error message for exception handling test --- tests/python/unittest/test_exc_handling.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/python/unittest/test_exc_handling.py b/tests/python/unittest/test_exc_handling.py index b7b8aa6c1d48..60799f821b8e 100644 --- a/tests/python/unittest/test_exc_handling.py +++ b/tests/python/unittest/test_exc_handling.py @@ -99,6 +99,9 @@ def gluon(exec_wait=True, waitall=False): @with_seed() def test_exc_multiple_waits(): def multiple_waits(waitall=False): + # Test calling failed op followed by wait_to_read or waitall twice + # Intention is to test rethrow for multiple wait_to_reads and waitalls + # for vars with exceptions in same scope caught = False try: a = mx.nd.random.normal(0, -1, (2, 2)).copyto(default_context()) @@ -108,7 +111,7 @@ def multiple_waits(waitall=False): a.wait_to_read() except MXNetError: caught = True - assert caught, "No exception thrown" + assert caught, "No exception thrown, exception should be rethrown with wait_to_read/waitall" try: b = mx.nd.random.normal(0, -1, (2, 2)).copyto(default_context()) if waitall: @@ -117,7 +120,7 @@ def multiple_waits(waitall=False): b.wait_to_read() except MXNetError: caught = True - assert caught, "No exception thrown" + assert caught, "No exception thrown, exception should be rethrown with wait_to_read/waitall" multiple_waits(waitall=False) multiple_waits(waitall=True) From 9e6972a3b9607473af711b5c202fcc8a0e6050ee Mon Sep 17 00:00:00 2001 From: Anirudh Subramanian Date: Thu, 21 Mar 2019 21:28:50 +0000 Subject: [PATCH 17/18] Change exception_ptr name in waitall --- src/engine/threaded_engine.cc | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/engine/threaded_engine.cc b/src/engine/threaded_engine.cc index 30d1e75f9a77..ef946fd8ecc8 100644 --- a/src/engine/threaded_engine.cc +++ b/src/engine/threaded_engine.cc @@ -415,21 +415,21 @@ void ThreadedEngine::WaitForAll() { finished_cv_.wait(lock, [this]() { return pending_.load() == 0 || kill_.load(); }); - std::exception_ptr tmp; + std::exception_ptr exception_to_rethrow; if (!global_exception_refs_.empty()) { // iterate through all exception refs for (const auto& global_exception_ref : global_exception_refs_) { // the first exception will be saved to be rethrown later - if (*global_exception_ref != nullptr && tmp == false) { - tmp = *global_exception_ref; + if (*global_exception_ref != nullptr && exception_to_rethrow == false) { + exception_to_rethrow = *global_exception_ref; } // clear exceptions, WaitToRead following WaitForAll shouldn't throw *global_exception_ref = nullptr; } // A waitall following a waitall shouldn't throw any exceptions global_exception_refs_.clear(); - if (tmp != nullptr) { - std::rethrow_exception(tmp); + if (exception_to_rethrow != nullptr) { + std::rethrow_exception(exception_to_rethrow); } } } From ff8151ce454cfc5b451a37b0d4f2c128b2d4a40e Mon Sep 17 00:00:00 2001 From: Anirudh Subramanian Date: Mon, 25 Mar 2019 20:50:36 +0000 Subject: [PATCH 18/18] Fix bug --- src/engine/threaded_engine.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/engine/threaded_engine.cc b/src/engine/threaded_engine.cc index ef946fd8ecc8..986b6ad29909 100644 --- a/src/engine/threaded_engine.cc +++ b/src/engine/threaded_engine.cc @@ -415,12 +415,12 @@ void ThreadedEngine::WaitForAll() { finished_cv_.wait(lock, [this]() { return pending_.load() == 0 || kill_.load(); }); - std::exception_ptr exception_to_rethrow; + std::exception_ptr exception_to_rethrow = nullptr; if (!global_exception_refs_.empty()) { // iterate through all exception refs for (const auto& global_exception_ref : global_exception_refs_) { // the first exception will be saved to be rethrown later - if (*global_exception_ref != nullptr && exception_to_rethrow == false) { + if (*global_exception_ref != nullptr && exception_to_rethrow == nullptr) { exception_to_rethrow = *global_exception_ref; } // clear exceptions, WaitToRead following WaitForAll shouldn't throw