Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Add exception handling support for waitall #14397

Merged
merged 25 commits into from
Apr 8, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
a85b3f0
Relax constexpr restriction
anirudh2290 Feb 5, 2019
5bd6428
Change imagenet_gen_qsym_mkldnn
anirudh2290 Feb 12, 2019
d684e4c
Merge branch 'master' of https://github.com/apache/incubator-mxnet
anirudh2290 Feb 12, 2019
5debcc2
Merge branch 'master' of https://github.com/apache/incubator-mxnet
anirudh2290 Feb 26, 2019
082a0aa
Merge branch 'master' of https://github.com/apache/incubator-mxnet
anirudh2290 Mar 1, 2019
f194aa2
Merge branch 'master' of https://github.com/apache/incubator-mxnet
anirudh2290 Mar 4, 2019
26079e6
Add exception handling support for waitall
anirudh2290 Mar 12, 2019
f0a76e3
Fix exception handling documentation
anirudh2290 Mar 12, 2019
c169a86
Fix quantization file
anirudh2290 Mar 12, 2019
cd98fa9
Revert constexpr change
anirudh2290 Mar 12, 2019
4f694f6
Add comments
anirudh2290 Mar 12, 2019
d79560b
Fix test
anirudh2290 Mar 12, 2019
3a581e8
Skip exception for op check names
anirudh2290 Mar 12, 2019
0b5444b
Print exceptions thrown for CPP Package NDArray module
anirudh2290 Mar 12, 2019
38b8dca
Reducing batch_size to make cpp-package example pass
anirudh2290 Mar 14, 2019
1c0d936
Fix bug: #14426
anirudh2290 Mar 14, 2019
034e9c7
use ExceptionRef in threaded_engine code
anirudh2290 Mar 14, 2019
d34d95e
add note for performance impact of waitall
anirudh2290 Mar 14, 2019
48a6638
Add check for GPU context
anirudh2290 Mar 20, 2019
f631d57
Use range for with const reference
anirudh2290 Mar 21, 2019
54e301f
Improve comments and error message for exception handling test
anirudh2290 Mar 21, 2019
9e6972a
Change exception_ptr name in waitall
anirudh2290 Mar 21, 2019
ee1fabd
Merge branch 'master' of https://github.com/apache/incubator-mxnet in…
anirudh2290 Mar 22, 2019
ff8151c
Fix bug
anirudh2290 Mar 25, 2019
2a13065
Merge branch 'master' of https://github.com/apache/incubator-mxnet in…
anirudh2290 Apr 2, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cpp-package/example/resnet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ int main(int argc, char const *argv[]) {
#if !MXNET_USE_CPU
if (num_gpu > 0) {
ctx = Context::gpu();
batch_size = 50;
batch_size = 32;
}
#endif

Expand Down
6 changes: 3 additions & 3 deletions cpp-package/include/mxnet-cpp/ndarray.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -233,12 +233,12 @@ inline NDArray NDArray::Reshape(const Shape &new_shape) const {
return NDArray(handle);
}
inline void NDArray::WaitToRead() const {
CHECK_EQ(MXNDArrayWaitToRead(blob_ptr_->handle_), 0);
CHECK_EQ(MXNDArrayWaitToRead(blob_ptr_->handle_), 0) << MXGetLastError();
}
inline void NDArray::WaitToWrite() {
CHECK_EQ(MXNDArrayWaitToWrite(blob_ptr_->handle_), 0);
CHECK_EQ(MXNDArrayWaitToWrite(blob_ptr_->handle_), 0) << MXGetLastError();
}
inline void NDArray::WaitAll() { CHECK_EQ(MXNDArrayWaitAll(), 0); }
inline void NDArray::WaitAll() { CHECK_EQ(MXNDArrayWaitAll(), 0) << MXGetLastError(); }
inline void NDArray::SampleGaussian(mx_float mu, mx_float sigma, NDArray *out) {
Operator("_random_normal")(mu, sigma).Invoke(*out);
}
Expand Down
3 changes: 0 additions & 3 deletions docs/architecture/exception_handling.md
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,3 @@ except mx.base.MXNetError as ex:
d.asnumpy()
```

### Limitation
anirudh2290 marked this conversation as resolved.
Show resolved Hide resolved

Rethrowing exceptions as part of `mx.nd.waitall` is not supported. So if your code executes a few operators and then calls `waitall` instead of `wait_to_read`/`asnumpy`, the exception will disappear. Please avoid `waitall` in your code unless you are confident that your code will not throw an exception in any scenario.
7 changes: 2 additions & 5 deletions python/mxnet/ndarray/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,12 +158,9 @@ def waitall():

This function is used for benchmarking only.

.. warning::
.. note::

If your code has exceptions, `waitall` can cause silent failures.
For this reason you should avoid `waitall` in your code.
Use it only if you are confident that your code is error free.
Then make sure you call `wait_to_read` on all outputs after `waitall`.
If your MXNet code throws an exception, then `waitall` can have a performance impact.
"""
check_call(_LIB.MXNDArrayWaitAll())

Expand Down
20 changes: 20 additions & 0 deletions src/engine/threaded_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,23 @@ void ThreadedEngine::WaitForAll() {
finished_cv_.wait(lock, [this]() {
return pending_.load() == 0 || kill_.load();
});
std::exception_ptr exception_to_rethrow = nullptr;
if (!global_exception_refs_.empty()) {
// iterate through all exception refs
for (const auto& global_exception_ref : global_exception_refs_) {
// the first exception will be saved to be rethrown later
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the order of the exceptions stored in `global_exception_refs_`? If we are throwing the first one, is it the innermost one in the stack (the one that causes all the other exceptions) or the outermost? If it's the outermost, then it might not give a correct idea of what the root cause was.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@access2rohit the order of the exceptions will be maintained: the exception thrown first will be rethrown first.

if (*global_exception_ref != nullptr && exception_to_rethrow == nullptr) {
exception_to_rethrow = *global_exception_ref;
}
// clear exceptions, WaitToRead following WaitForAll shouldn't throw
*global_exception_ref = nullptr;
}
// A waitall following a waitall shouldn't throw any exceptions
global_exception_refs_.clear();
if (exception_to_rethrow != nullptr) {
std::rethrow_exception(exception_to_rethrow);
}
}
}

inline void ThreadedEngine::OnComplete(ThreadedOpr* threaded_opr) {
Expand All @@ -428,6 +445,9 @@ inline void ThreadedEngine::OnComplete(ThreadedOpr* threaded_opr) {
for (auto&& i : threaded_opr->mutable_vars) {
if (threaded_opr->opr_exception && *threaded_opr->opr_exception) {
i->var_exception = threaded_opr->opr_exception;
// add current operator exceptions to global exceptions if not already
// added
AddToGlobalExceptions(threaded_opr->opr_exception);
}
const bool debug_info = (engine_info_ && debug_wait_var_ == i);
if (debug_info) {
Expand Down
36 changes: 32 additions & 4 deletions src/engine/threaded_engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ namespace engine {
// Forward declarations
struct ThreadedOpr;

/*! shared_ptr to exception_ptr, used for exception handling */
typedef std::shared_ptr<std::exception_ptr> ExceptionRef;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to wrap it in a shared_ptr? exception_ptr has already shared ptr semantics according to cppreference.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, the name Ref to me is confusing, why not Ptr? why add a suffix of the type at all?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

exception_ptr cannot be dereferenced, so we cannot update the exception object it is pointing to or make it nullptr. Since this is a requirement for us, we wrapped it in a shared_ptr. I used "Ref" to make it consistent with other places in MXNet.


/*!
* \brief Operation block in the scheduler.
* Each OprBlock corresponds to an operation pushed to the engine.
Expand Down Expand Up @@ -177,8 +180,12 @@ class ThreadedVar final
static std::atomic<std::size_t> counter;
~ThreadedVar() { LOG(INFO) << __func__ << " " << --counter; }
#endif // ENGINE_DEBUG
/*! \brief exception_ptr associated with the ThreadedVar */
std::shared_ptr<std::exception_ptr> var_exception;
/*!
* \brief exception_ptr associated with the ThreadedVar
* cannot modify state of exception object since dereferencing
* exception_ptr is undefined behavior. Using shared_ptr to hold
anirudh2290 marked this conversation as resolved.
Show resolved Hide resolved
* exception_ptr and overcome this limitation */
ExceptionRef var_exception;

private:
// TODO(hotpxl) change this to spinlock for faster runtime
Expand Down Expand Up @@ -254,8 +261,12 @@ struct ThreadedOpr final : public Opr,
}
// define possible debug information
DEFINE_ENGINE_DEBUG_INFO(ThreadedOpr);
/*! \brief exception_ptr associated with the ThreadedOpr */
std::shared_ptr<std::exception_ptr> opr_exception;
/*!
* \brief exception_ptr associated with the ThreadedOpr
* cannot modify state of exception object since dereferencing
* exception_ptr is undefined behavior. Using shared_ptr to hold
* exception_ptr and overcome this limitation */
ExceptionRef opr_exception;
}; // struct ThreadedOpr

/*!
Expand Down Expand Up @@ -432,6 +443,7 @@ class ThreadedEngine : public Engine {
};
/*! thread local store for bulk */
typedef dmlc::ThreadLocalStore<BulkStatus> BulkStatusStore;

/*!
* \brief check if there is duplication in const_vars and mutable_vars.
* \param const_vars the variables to read from.
Expand Down Expand Up @@ -460,13 +472,15 @@ class ThreadedEngine : public Engine {
for (auto&& i : threaded_opr->const_vars) {
if (i->var_exception && *i->var_exception) {
threaded_opr->opr_exception = i->var_exception;
AddToGlobalExceptions(threaded_opr->opr_exception);
break;
}
}
if (!(threaded_opr->opr_exception && *threaded_opr->opr_exception)) {
for (auto&& i : threaded_opr->mutable_vars) {
if (i->var_exception && *i->var_exception) {
threaded_opr->opr_exception = i->var_exception;
AddToGlobalExceptions(threaded_opr->opr_exception);
break;
}
}
Expand All @@ -475,6 +489,18 @@ class ThreadedEngine : public Engine {

static void OnCompleteStatic(Engine *engine, void *threaded_opr,
const dmlc::Error* error);
/*!
 * \brief record an exception ref in global_exception_refs_, unless a ref that
 *  compares equal to it is already stored there
 * \param opr_exception the exception ref to record
 */
inline void AddToGlobalExceptions(const ExceptionRef& opr_exception) {
  const auto refs_end = global_exception_refs_.end();
  const bool already_recorded =
      std::find(global_exception_refs_.begin(), refs_end, opr_exception) != refs_end;
  if (already_recorded) {
    // nothing to do: keep a single entry per exception ref
    return;
  }
  global_exception_refs_.push_back(opr_exception);
}
/*! \brief append an operator to bulk */
inline void BulkAppend(SyncFn exec_fn, Context exec_ctx,
std::vector<VarHandle> const& const_vars,
Expand Down Expand Up @@ -542,6 +568,8 @@ class ThreadedEngine : public Engine {
*/
std::mutex finished_m_;
std::condition_variable finished_cv_;
/*! \brief global exception refs, which are rethrown when WaitForAll is called */
std::vector<ExceptionRef> global_exception_refs_;

/*!
* \brief Holding a shared_ptr to the object pool to prevent it from being destructed too early
Expand Down
14 changes: 8 additions & 6 deletions src/resource.cc
Original file line number Diff line number Diff line change
Expand Up @@ -189,12 +189,14 @@ class ResourceManagerImpl : public ResourceManager {
cpu_rand_->Seed(seed);
cpu_parallel_rand_->Seed(seed);
#if MXNET_USE_CUDA
gpu_rand_.Get(ctx.dev_id, [ctx, seed, this]() {
return new ResourceRandom<gpu>(ctx, seed);
})->Seed(seed);
gpu_parallel_rand_.Get(ctx.dev_id, [ctx, seed, this]() {
return new ResourceParallelRandom<gpu>(ctx, gpu_native_rand_copy_, seed);
})->Seed(seed);
if (ctx.dev_type == Context::kGPU) {
gpu_rand_.Get(ctx.dev_id, [ctx, seed, this]() {
return new ResourceRandom<gpu>(ctx, seed);
})->Seed(seed);
gpu_parallel_rand_.Get(ctx.dev_id, [ctx, seed, this]() {
return new ResourceParallelRandom<gpu>(ctx, gpu_native_rand_copy_, seed);
})->Seed(seed);
}
#endif
}

Expand Down
113 changes: 81 additions & 32 deletions tests/python/unittest/test_exc_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,11 @@ def imperative(exec_numpy=True):
c.asnumpy()

imperative(exec_numpy=False)
assert_raises(MXNetError, imperative, True)
assert_raises(MXNetError, imperative, exec_numpy=True)

@with_seed()
def test_exc_symbolic():
def symbolic(exec_backward=True):
def symbolic(exec_backward=True, waitall=True):
x = mx.sym.Variable('x')
y = mx.sym.Variable('y')
z = mx.sym.Variable('z')
Expand All @@ -58,16 +58,25 @@ def symbolic(exec_backward=True):
outputs = exec1.forward()
if exec_backward:
exec1.backward()
exec1.grad_arrays[0].asnumpy()
if waitall:
mx.nd.waitall()
else:
exec1.grad_arrays[0].asnumpy()
else:
outputs[0].asnumpy()
if waitall:
mx.nd.waitall()
else:
outputs[0].asnumpy()

assert_raises(MXNetError, symbolic, False)
assert_raises(MXNetError, symbolic, True)
assert_raises(MXNetError, symbolic, exec_backward=False)
assert_raises(MXNetError, symbolic, exec_backward=True)

assert_raises(MXNetError, symbolic, exec_backward=False, waitall=True)
assert_raises(MXNetError, symbolic, exec_backward=True, waitall=True)

@with_seed()
def test_exc_gluon():
def gluon(exec_wait=True):
def gluon(exec_wait=True, waitall=False):
model = nn.Sequential()
model.add(nn.Dense(128, activation='tanh', in_units=10, flatten=False))
model.add(nn.Dropout(1))
Expand All @@ -77,46 +86,86 @@ def gluon(exec_wait=True):
y = model(x)
model.collect_params().initialize(ctx=[default_context()])
z = model(mx.nd.random.normal(10, -10, (32, 2, 10), ctx=default_context()))
if exec_wait:
if waitall:
mx.nd.waitall()
elif exec_wait:
z.wait_to_read()

gluon(exec_wait=False)
assert_raises(MXNetError, gluon, True)
assert_raises(MXNetError, gluon, exec_wait=True)

assert_raises(MXNetError, gluon, waitall=True)

@with_seed()
def test_exc_multiple_waits():
caught = False
try:
a = mx.nd.random.normal(0, -1, (2, 2)).copyto(default_context())
a.wait_to_read()
except MXNetError:
caught = True
assert caught, "No exception thrown"
try:
b = mx.nd.random.normal(0, -1, (2, 2)).copyto(default_context())
b.wait_to_read()
except MXNetError:
caught = True
assert caught, "No exception thrown"
def multiple_waits(waitall=False):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does it make sense to use "@raises"? maybe it would be easier to read.

https://nose.readthedocs.io/en/latest/testing_tools.html

At least a small comment explaining the test approach for future readers and that we expect exception to be thrown, is that the intent?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have added comments. Intention is to test multiple wait_to_reads and waitalls for vars in same scope.

# Test calling failed op followed by wait_to_read or waitall twice
# Intention is to test rethrow for multiple wait_to_reads and waitalls
# for vars with exceptions in same scope
caught = False
try:
a = mx.nd.random.normal(0, -1, (2, 2)).copyto(default_context())
if waitall:
mx.nd.waitall()
else:
a.wait_to_read()
except MXNetError:
caught = True
assert caught, "No exception thrown, exception should be rethrown with wait_to_read/waitall"
try:
b = mx.nd.random.normal(0, -1, (2, 2)).copyto(default_context())
if waitall:
mx.nd.waitall()
else:
b.wait_to_read()
except MXNetError:
caught = True
assert caught, "No exception thrown, exception should be rethrown with wait_to_read/waitall"

multiple_waits(waitall=False)
multiple_waits(waitall=True)

@with_seed()
def test_exc_post_fail():
    """Check that an MXNetError is raised at the first sync point (asnumpy or
    waitall) and that reading the second array afterwards does not raise."""
    def post_fail(waitall=False):
        raised = False
        try:
            # NOTE(review): the scale argument is -1, presumably the invalid
            # input that makes the operator fail -- confirm.
            first, second = mx.nd.random_normal(0, -1, (2, 2)).copyto(default_context())
            if not waitall:
                first.asnumpy()
            else:
                mx.nd.waitall()
        except MXNetError:
            raised = True
        assert raised, "No exception thrown"
        # Outside the try block: a second raise here would fail the test.
        second.asnumpy()

    for use_waitall in (False, True):
        post_fail(waitall=use_waitall)

@with_seed()
def test_exc_mutable_var_fail():
    """Check that an MXNetError is raised when the result of a dot product
    computed from a failed array is synchronized via asnumpy or waitall."""
    def mutable_var_check(waitall=False):
        a, _unused = mx.nd.random_normal(0, -1, (2, 2)).copyto(default_context())
        # Rebind `a` to the result of an op that consumes the failed array.
        a = mx.nd.dot(a, a)
        if not waitall:
            a.asnumpy()
        else:
            mx.nd.waitall()

    for use_waitall in (False, True):
        assert_raises(MXNetError, mutable_var_check, waitall=use_waitall)

@with_seed()
def test_multiple_waitalls():
caught = False
try:
a, b = mx.nd.random_normal(0, -1, (2, 2)).copyto(default_context())
a.asnumpy()
a = mx.nd.random.normal(0, -1, (2, 2)).copyto(default_context())
mx.nd.waitall()
except MXNetError:
caught = True
assert caught, "No exception thrown"
b.asnumpy()
mx.nd.waitall()


@with_seed()
def test_exc_mutable_var_fail():
def mutable_var_check():
a, b = mx.nd.random_normal(0, -1, (2, 2)).copyto(default_context())
a = mx.nd.dot(a, a)
a.asnumpy()
assert_raises(MXNetError, mutable_var_check)

if __name__ == '__main__':
import nose
Expand Down
14 changes: 12 additions & 2 deletions tests/python/unittest/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7131,7 +7131,12 @@ def get_output_names_callback(name, arr):

op_exe = op_sym.simple_bind(ctx=mx.current_context(), grad_req='null')
op_exe.set_monitor_callback(get_output_names_callback, monitor_all=False)
op_exe.forward()
try:
op_exe.forward()
mx.nd.waitall()
except mx.base.MXNetError:
# skip errors since test is to check output names
pass
for output_name, expected_name in zip(output_names, expected_names):
assert output_name == expected_name

Expand Down Expand Up @@ -7177,7 +7182,12 @@ def get_output_names_callback(name, arr):

op_exe = op_sym.simple_bind(ctx=mx.current_context(), grad_req='null')
op_exe.set_monitor_callback(get_output_names_callback, monitor_all=True)
op_exe.forward()
try:
op_exe.forward()
mx.nd.waitall()
except mx.base.MXNetError:
# skip errors since test is to check all names
pass
for output_name, expected_name in zip(output_names, expected_names):
assert output_name == expected_name

Expand Down