Skip to content

Commit

Permalink
Add exception handling support for waitall (apache#14397)
Browse files Browse the repository at this point in the history
* Relax constexpr restriction

* Change imagenet_gen_qsym_mkldnn

* Add exception handling support for waitall

* Fix exception handling documentation

* Revert constexpr change

* Add comments

* Fix test

* Skip exception for op check names

* Print exceptions thrown for CPP Package NDArray module

* Reducing batch_size to make cpp-package example pass

* Fix bug: apache#14426

* use ExceptionRef in threaded_engine code

* add note for performance impact of waitall

* Add check for GPU context

* Use range for with const reference

* Improve comments and error message for exception handling test

* Change exception_ptr name in waitall

* Fix bug
  • Loading branch information
anirudh2290 authored and haohuw committed Jun 23, 2019
1 parent d76817f commit 456ca1f
Show file tree
Hide file tree
Showing 9 changed files with 159 additions and 56 deletions.
2 changes: 1 addition & 1 deletion cpp-package/example/resnet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ int main(int argc, char const *argv[]) {
#if !MXNET_USE_CPU
if (num_gpu > 0) {
ctx = Context::gpu();
batch_size = 50;
batch_size = 32;
}
#endif

Expand Down
6 changes: 3 additions & 3 deletions cpp-package/include/mxnet-cpp/ndarray.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -233,12 +233,12 @@ inline NDArray NDArray::Reshape(const Shape &new_shape) const {
return NDArray(handle);
}
inline void NDArray::WaitToRead() const {
CHECK_EQ(MXNDArrayWaitToRead(blob_ptr_->handle_), 0);
CHECK_EQ(MXNDArrayWaitToRead(blob_ptr_->handle_), 0) << MXGetLastError();
}
inline void NDArray::WaitToWrite() {
CHECK_EQ(MXNDArrayWaitToWrite(blob_ptr_->handle_), 0);
CHECK_EQ(MXNDArrayWaitToWrite(blob_ptr_->handle_), 0) << MXGetLastError();
}
inline void NDArray::WaitAll() { CHECK_EQ(MXNDArrayWaitAll(), 0); }
inline void NDArray::WaitAll() { CHECK_EQ(MXNDArrayWaitAll(), 0) << MXGetLastError(); }
inline void NDArray::SampleGaussian(mx_float mu, mx_float sigma, NDArray *out) {
Operator("_random_normal")(mu, sigma).Invoke(*out);
}
Expand Down
3 changes: 0 additions & 3 deletions docs/architecture/exception_handling.md
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,3 @@ except mx.base.MXNetError as ex:
d.asnumpy()
```

### Limitation

Rethrowing exceptions as part of `mx.nd.waitall` is not supported. So if your code executes a few operators and then calls `waitall` instead of `wait_to_read`/`asnumpy`, the exception will disappear. Please avoid waitalls in your code unless you are confident about your code not throwing exception in any scenario.
7 changes: 2 additions & 5 deletions python/mxnet/ndarray/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,12 +158,9 @@ def waitall():
This function is used for benchmarking only.
.. warning::
.. note::
If your code has exceptions, `waitall` can cause silent failures.
For this reason you should avoid `waitall` in your code.
Use it only if you are confident that your code is error free.
Then make sure you call `wait_to_read` on all outputs after `waitall`.
If your mxnet code throws an exception, then waitall can cause performance impact.
"""
check_call(_LIB.MXNDArrayWaitAll())

Expand Down
20 changes: 20 additions & 0 deletions src/engine/threaded_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,23 @@ void ThreadedEngine::WaitForAll() {
finished_cv_.wait(lock, [this]() {
return pending_.load() == 0 || kill_.load();
});
std::exception_ptr exception_to_rethrow = nullptr;
if (!global_exception_refs_.empty()) {
// iterate through all exception refs
for (const auto& global_exception_ref : global_exception_refs_) {
// the first exception will be saved to be rethrown later
if (*global_exception_ref != nullptr && exception_to_rethrow == nullptr) {
exception_to_rethrow = *global_exception_ref;
}
// clear exceptions, WaitToRead following WaitForAll shouldn't throw
*global_exception_ref = nullptr;
}
// A waitall following a waitall shouldn't throw any exceptions
global_exception_refs_.clear();
if (exception_to_rethrow != nullptr) {
std::rethrow_exception(exception_to_rethrow);
}
}
}

inline void ThreadedEngine::OnComplete(ThreadedOpr* threaded_opr) {
Expand All @@ -428,6 +445,9 @@ inline void ThreadedEngine::OnComplete(ThreadedOpr* threaded_opr) {
for (auto&& i : threaded_opr->mutable_vars) {
if (threaded_opr->opr_exception && *threaded_opr->opr_exception) {
i->var_exception = threaded_opr->opr_exception;
// add current operator exceptions to global exceptions if not already
// added
AddToGlobalExceptions(threaded_opr->opr_exception);
}
const bool debug_info = (engine_info_ && debug_wait_var_ == i);
if (debug_info) {
Expand Down
36 changes: 32 additions & 4 deletions src/engine/threaded_engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ namespace engine {
// Forward declarations
struct ThreadedOpr;

/*! shared_ptr to exception_ptr, used for exception handling */
typedef std::shared_ptr<std::exception_ptr> ExceptionRef;

/*!
* \brief Operation block in the scheduler.
* Each OprBlock corresponds to an operation pushed to the engine.
Expand Down Expand Up @@ -177,8 +180,12 @@ class ThreadedVar final
static std::atomic<std::size_t> counter;
~ThreadedVar() { LOG(INFO) << __func__ << " " << --counter; }
#endif // ENGINE_DEBUG
/*! \brief exception_ptr associated with the ThreadedVar */
std::shared_ptr<std::exception_ptr> var_exception;
/*!
* \brief exception_ptr associated with the ThreadedVar
* cannot modify state of exception object since dereferencing
* exception_ptr is undefined behavior. Using shared_ptr to hold
* exception_ptr and overcome this limitation */
ExceptionRef var_exception;

private:
// TODO(hotpxl) change this to spinlock for faster runtime
Expand Down Expand Up @@ -254,8 +261,12 @@ struct ThreadedOpr final : public Opr,
}
// define possible debug information
DEFINE_ENGINE_DEBUG_INFO(ThreadedOpr);
/*! \brief exception_ptr associated with the ThreadedOpr */
std::shared_ptr<std::exception_ptr> opr_exception;
/*!
* \brief exception_ptr associated with the ThreadedOpr
* cannot modify state of exception object since dereferencing
* exception_ptr is undefined behavior. Using shared_ptr to hold
* exception_ptr and overcome this limitation */
ExceptionRef opr_exception;
}; // struct ThreadedOpr

/*!
Expand Down Expand Up @@ -432,6 +443,7 @@ class ThreadedEngine : public Engine {
};
/*! thread local store for bulk */
typedef dmlc::ThreadLocalStore<BulkStatus> BulkStatusStore;

/*!
* \brief check if there is duplication in const_vars and mutable_vars.
* \param const_vars the variables to read from.
Expand Down Expand Up @@ -460,13 +472,15 @@ class ThreadedEngine : public Engine {
for (auto&& i : threaded_opr->const_vars) {
if (i->var_exception && *i->var_exception) {
threaded_opr->opr_exception = i->var_exception;
AddToGlobalExceptions(threaded_opr->opr_exception);
break;
}
}
if (!(threaded_opr->opr_exception && *threaded_opr->opr_exception)) {
for (auto&& i : threaded_opr->mutable_vars) {
if (i->var_exception && *i->var_exception) {
threaded_opr->opr_exception = i->var_exception;
AddToGlobalExceptions(threaded_opr->opr_exception);
break;
}
}
Expand All @@ -475,6 +489,18 @@ class ThreadedEngine : public Engine {

static void OnCompleteStatic(Engine *engine, void *threaded_opr,
const dmlc::Error* error);
/*!
 * \brief record an exception reference in global_exception_refs_, skipping
 *  it when the same reference has already been recorded
 * \param opr_exception the exception reference to record
 */
inline void AddToGlobalExceptions(const ExceptionRef& opr_exception) {
  // ExceptionRef is a shared_ptr, so equality compares the held pointers:
  // this detects the same exception holder, not an equivalent exception.
  for (const auto& recorded_ref : global_exception_refs_) {
    if (recorded_ref == opr_exception) {
      return;
    }
  }
  global_exception_refs_.push_back(opr_exception);
}
/*! \brief append an operator to bulk */
inline void BulkAppend(SyncFn exec_fn, Context exec_ctx,
std::vector<VarHandle> const& const_vars,
Expand Down Expand Up @@ -542,6 +568,8 @@ class ThreadedEngine : public Engine {
*/
std::mutex finished_m_;
std::condition_variable finished_cv_;
/*! \brief global exception refs, which are rethrown when WaitForAll is called */
std::vector<ExceptionRef> global_exception_refs_;

/*!
* \brief Holding a shared_ptr to the object pool to prevent it from being destructed too early
Expand Down
14 changes: 8 additions & 6 deletions src/resource.cc
Original file line number Diff line number Diff line change
Expand Up @@ -189,12 +189,14 @@ class ResourceManagerImpl : public ResourceManager {
cpu_rand_->Seed(seed);
cpu_parallel_rand_->Seed(seed);
#if MXNET_USE_CUDA
gpu_rand_.Get(ctx.dev_id, [ctx, seed, this]() {
return new ResourceRandom<gpu>(ctx, seed);
})->Seed(seed);
gpu_parallel_rand_.Get(ctx.dev_id, [ctx, seed, this]() {
return new ResourceParallelRandom<gpu>(ctx, gpu_native_rand_copy_, seed);
})->Seed(seed);
if (ctx.dev_type == Context::kGPU) {
gpu_rand_.Get(ctx.dev_id, [ctx, seed, this]() {
return new ResourceRandom<gpu>(ctx, seed);
})->Seed(seed);
gpu_parallel_rand_.Get(ctx.dev_id, [ctx, seed, this]() {
return new ResourceParallelRandom<gpu>(ctx, gpu_native_rand_copy_, seed);
})->Seed(seed);
}
#endif
}

Expand Down
113 changes: 81 additions & 32 deletions tests/python/unittest/test_exc_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,11 @@ def imperative(exec_numpy=True):
c.asnumpy()

imperative(exec_numpy=False)
assert_raises(MXNetError, imperative, True)
assert_raises(MXNetError, imperative, exec_numpy=True)

@with_seed()
def test_exc_symbolic():
def symbolic(exec_backward=True):
def symbolic(exec_backward=True, waitall=True):
x = mx.sym.Variable('x')
y = mx.sym.Variable('y')
z = mx.sym.Variable('z')
Expand All @@ -58,16 +58,25 @@ def symbolic(exec_backward=True):
outputs = exec1.forward()
if exec_backward:
exec1.backward()
exec1.grad_arrays[0].asnumpy()
if waitall:
mx.nd.waitall()
else:
exec1.grad_arrays[0].asnumpy()
else:
outputs[0].asnumpy()
if waitall:
mx.nd.waitall()
else:
outputs[0].asnumpy()

assert_raises(MXNetError, symbolic, False)
assert_raises(MXNetError, symbolic, True)
assert_raises(MXNetError, symbolic, exec_backward=False)
assert_raises(MXNetError, symbolic, exec_backward=True)

assert_raises(MXNetError, symbolic, exec_backward=False, waitall=True)
assert_raises(MXNetError, symbolic, exec_backward=True, waitall=True)

@with_seed()
def test_exc_gluon():
def gluon(exec_wait=True):
def gluon(exec_wait=True, waitall=False):
model = nn.Sequential()
model.add(nn.Dense(128, activation='tanh', in_units=10, flatten=False))
model.add(nn.Dropout(1))
Expand All @@ -77,46 +86,86 @@ def gluon(exec_wait=True):
y = model(x)
model.collect_params().initialize(ctx=[default_context()])
z = model(mx.nd.random.normal(10, -10, (32, 2, 10), ctx=default_context()))
if exec_wait:
if waitall:
mx.nd.waitall()
elif exec_wait:
z.wait_to_read()

gluon(exec_wait=False)
assert_raises(MXNetError, gluon, True)
assert_raises(MXNetError, gluon, exec_wait=True)

assert_raises(MXNetError, gluon, waitall=True)

@with_seed()
def test_exc_multiple_waits():
    """Check that an exception is rethrown on each of two consecutive waits."""
    def multiple_waits(waitall=False):
        # Test calling failed op followed by wait_to_read or waitall twice
        # Intention is to test rethrow for multiple wait_to_reads and waitalls
        # for vars with exceptions in same scope
        caught = False
        try:
            a = mx.nd.random.normal(0, -1, (2, 2)).copyto(default_context())
            if waitall:
                mx.nd.waitall()
            else:
                a.wait_to_read()
        except MXNetError:
            caught = True
        assert caught, "No exception thrown, exception should be rethrown with wait_to_read/waitall"
        # Reset the flag: otherwise the result of the first wait would make the
        # second assertion pass even if the second wait did not rethrow.
        caught = False
        try:
            b = mx.nd.random.normal(0, -1, (2, 2)).copyto(default_context())
            if waitall:
                mx.nd.waitall()
            else:
                b.wait_to_read()
        except MXNetError:
            caught = True
        assert caught, "No exception thrown, exception should be rethrown with wait_to_read/waitall"

    multiple_waits(waitall=False)
    multiple_waits(waitall=True)

@with_seed()
def test_exc_post_fail():
def post_fail(waitall=False):
caught = False
try:
a, b = mx.nd.random_normal(0, -1, (2, 2)).copyto(default_context())
if waitall:
mx.nd.waitall()
else:
a.asnumpy()
except MXNetError:
caught = True
assert caught, "No exception thrown"
b.asnumpy()
post_fail(waitall=False)
post_fail(waitall=True)

@with_seed()
def test_exc_mutable_var_fail():
def mutable_var_check(waitall=False):
a, b = mx.nd.random_normal(0, -1, (2, 2)).copyto(default_context())
a = mx.nd.dot(a, a)
if waitall:
mx.nd.waitall()
else:
a.asnumpy()
assert_raises(MXNetError, mutable_var_check, waitall=False)
assert_raises(MXNetError, mutable_var_check, waitall=True)

@with_seed()
def test_multiple_waitalls():
caught = False
try:
a, b = mx.nd.random_normal(0, -1, (2, 2)).copyto(default_context())
a.asnumpy()
a = mx.nd.random.normal(0, -1, (2, 2)).copyto(default_context())
mx.nd.waitall()
except MXNetError:
caught = True
assert caught, "No exception thrown"
b.asnumpy()
mx.nd.waitall()


@with_seed()
def test_exc_mutable_var_fail():
def mutable_var_check():
a, b = mx.nd.random_normal(0, -1, (2, 2)).copyto(default_context())
a = mx.nd.dot(a, a)
a.asnumpy()
assert_raises(MXNetError, mutable_var_check)

if __name__ == '__main__':
import nose
Expand Down
14 changes: 12 additions & 2 deletions tests/python/unittest/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7164,7 +7164,12 @@ def get_output_names_callback(name, arr):

op_exe = op_sym.simple_bind(ctx=mx.current_context(), grad_req='null')
op_exe.set_monitor_callback(get_output_names_callback, monitor_all=False)
op_exe.forward()
try:
op_exe.forward()
mx.nd.waitall()
except mx.base.MXNetError:
# skip errors since test is to check output names
pass
for output_name, expected_name in zip(output_names, expected_names):
assert output_name == expected_name

Expand Down Expand Up @@ -7210,7 +7215,12 @@ def get_output_names_callback(name, arr):

op_exe = op_sym.simple_bind(ctx=mx.current_context(), grad_req='null')
op_exe.set_monitor_callback(get_output_names_callback, monitor_all=True)
op_exe.forward()
try:
op_exe.forward()
mx.nd.waitall()
except mx.base.MXNetError:
# skip errors since test is to check all names
pass
for output_name, expected_name in zip(output_names, expected_names):
assert output_name == expected_name

Expand Down

0 comments on commit 456ca1f

Please sign in to comment.