Better Exception Handling for Operators #9681
This is going to be a big improvement for new users experimenting with the library. Thanks for the great work @anirudh2290.
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing,
Nit: spacing
Good catch! Fixed.
include/mxnet/engine.h
Outdated
@@ -182,7 +182,7 @@ class MXNET_API Engine {
std::vector<VarHandle> const& mutable_vars,
FnProperty prop = FnProperty::kNormal,
int priority = 0,
const char* opr_name = nullptr) = 0;
const char* opr_name = nullptr, bool wait = false) = 0;
What's wait? Why do you need it? Please document the arguments.
wait indicates to ExecuteOprBlock whether the operation is a WaitForVar operation. It should not block the execution of the operator for WaitForVar.
src/engine/threaded_engine.cc
Outdated
if (threaded_var->ex_ptr) {
std::rethrow_exception(threaded_var->ex_ptr);
}
}
why are you not returning?
Good catch! Fixed.
src/engine/threaded_engine.cc
Outdated
@@ -391,6 +400,11 @@ void ThreadedEngine::WaitForAll() {
finished_cv_.wait(lock, [this]() {
return pending_.load() == 0 || kill_.load();
});
if (global_ex_ptr) {
std::exception_ptr ex_ptr = global_ex_ptr;
use std::rethrow_exception(std::move(global_ex_ptr))
Fixed.
Rolling this back, since the state of global_ex_ptr is not guaranteed to be nullptr after the move; it depends on the implementation. This probably explains why it started failing on Windows after the change. Please let me know if you have any concerns.
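The concern in the comment above can be illustrated with a minimal sketch. It is not the code from this PR: `global_exception` and `RethrowOnce` are hypothetical names. The sketch copies the stored exception_ptr and resets the slot explicitly, so the slot's state after the rethrow does not rely on what a moved-from exception_ptr holds.

```cpp
#include <cassert>
#include <exception>
#include <stdexcept>

// Hypothetical stand-in for the engine-wide exception slot discussed above.
static std::exception_ptr global_exception;

// Rethrow the stored exception at most once.
// The slot is copied and then reset explicitly instead of being moved from,
// so its state afterwards does not depend on the standard library in use.
void RethrowOnce() {
  if (global_exception) {
    std::exception_ptr pending = global_exception;  // copy keeps the exception alive
    global_exception = nullptr;                     // explicit reset
    assert(global_exception == nullptr);
    std::rethrow_exception(pending);
  }
}
```

A second call to `RethrowOnce()` after the first one has thrown is a no-op, which is the "rethrow only once" behavior the thread is aiming for.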
src/engine/threaded_engine.cc
Outdated
i->ex_ptr = threaded_opr->ex_ptr;
if (!global_ex_ptr) global_ex_ptr = i->ex_ptr;
}
bool debug_info = (engine_info_ && debug_wait_var_ == i);
why remove const?
This was not intentional. Added it back.
src/engine/threaded_engine.cc
Outdated
if (threaded_opr->ex_ptr) {
i->ex_ptr = threaded_opr->ex_ptr;
if (!global_ex_ptr) global_ex_ptr = i->ex_ptr;
}
Suppose an operator has three outputs x, y and z and it raises an exception.
Then x.asnumpy() would raise an error.
Then y.asnumpy() would raise the same error again.
And if I do z += 1 and it succeeds, z.asnumpy() would still raise the error.
z += 1 won't execute since z already has an exception associated with it. z.asnumpy() will still raise the error.
src/engine/threaded_engine.h
Outdated
@@ -175,6 +175,8 @@ class ThreadedVar final
static std::atomic<std::size_t> counter;
~ThreadedVar() { LOG(INFO) << __func__ << " " << --counter; }
#endif // ENGINE_DEBUG
/*! \brief exception_ptr associated with the ThreadedVar */
std::exception_ptr ex_ptr;
ex_ptr is a bad variable name
I changed ex_ptr for var to var_exception, ex_ptr for opr to opr_exception, and the global exception_ptr to global_exception_. I welcome any other suggestions for naming them.
src/engine/threaded_engine.h
Outdated
@@ -338,33 +346,46 @@ class ThreadedEngine : public Engine {
#endif
CallbackOnComplete callback = this->CreateCallback(
ThreadedEngine::OnCompleteStatic, opr_block);
CallbackOnComplete on_start_callback = this->CreateCallback(
What's the point of creating a callback here?
The callback is not strictly necessary here, since it is called only once, but I included it to keep the ExecuteOprBlock easier to read and separate out the OnStart logic.
I think this is unnecessary overhead. Call OnStart directly if possible
Called OnStart directly.
} else {
callback();
}
} catch (dmlc::Error& e) {
Why only catch dmlc::Error?
The reason I catch only dmlc::Error is that the guards in the c_api (API_BEGIN and API_END/API_END_HANDLE_ERROR) currently catch only dmlc::Error and propagate it to the frontend.
I have used dmlc::Error in the inner block and std::exception in the outer block to catch other stdlib exceptions. Currently, only dmlc::Error is caught and propagated; for other exceptions (std::exception) the process is terminated. I will open another PR to handle std::exception and change the c_api guards and frontend code.
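The two-level guard described above might look roughly like the following sketch. It is an illustration under stated assumptions, not the code in this PR: `OperatorError` is a hypothetical stand-in for dmlc::Error (which likewise derives from std::runtime_error), and `RunGuarded` and `Outcome` are made-up names. Where the PR terminates the process, the sketch returns `Outcome::kFatal` so that it stays runnable.

```cpp
#include <cassert>
#include <exception>
#include <functional>
#include <stdexcept>
#include <string>

// Hypothetical stand-in for dmlc::Error.
struct OperatorError : public std::runtime_error {
  explicit OperatorError(const std::string& msg) : std::runtime_error(msg) {}
};

enum class Outcome { kOk, kCaptured, kFatal };

// Run fn under an inner and an outer guard.
// Inner guard: capture OperatorError so it can be rethrown later at a wait point.
// Outer guard: any other std::exception is treated as fatal (the PR terminates here).
Outcome RunGuarded(const std::function<void()>& fn, std::exception_ptr* captured) {
  assert(captured != nullptr);
  try {
    try {
      fn();
    } catch (const OperatorError&) {
      *captured = std::current_exception();  // keep it for a later rethrow
      return Outcome::kCaptured;
    }
  } catch (const std::exception&) {
    return Outcome::kFatal;
  }
  return Outcome::kOk;
}
```

Because the inner handler names only the narrower type, an unrelated standard exception such as std::out_of_range passes through it and reaches the outer handler.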
include/mxnet/engine.h
Outdated
@@ -182,7 +182,7 @@ class MXNET_API Engine {
std::vector<VarHandle> const& mutable_vars,
FnProperty prop = FnProperty::kNormal,
int priority = 0,
const char* opr_name = nullptr) = 0;
const char* opr_name = nullptr, bool wait = false) = 0;
- Add \param for wait in the doc.
- Put bool wait... on the next line to keep the coding style consistent with the existing context. Same for all the following changes.
Good catch! Fixed here and other places.
src/engine/naive_engine.cc
Outdated
@@ -73,7 +73,7 @@ class NaiveEngine final : public Engine {
std::vector<VarHandle> const& const_vars,
std::vector<VarHandle> const& mutable_vars,
FnProperty prop = FnProperty::kNormal,
const char* opr_name = nullptr) override {
const char* opr_name = nullptr, bool wait = false) override {
Put in the next line.
src/engine/naive_engine.cc
Outdated
@@ -125,7 +125,7 @@ class NaiveEngine final : public Engine {
std::vector<VarHandle> const& mutable_vars,
FnProperty prop = FnProperty::kNormal,
int priority = 0,
const char* opr_name = nullptr) override {
const char* opr_name = nullptr, bool wait = false) override {
Next line.
inputs = [x, y]
out = mx.symbol.ElementWiseSum(*inputs, name="esum")
out = mx.sym.dot(z, out)
out2 = mx.sym.random_normal(0, -1, x_shape, ctx=default_context())
random.normal as it's the preferred way now.
Fixed.
src/engine/threaded_engine.h
Outdated
if (debug_info) {
LOG(INFO) << "Fin ExecuteOprFn ";
}
} catch(dmlc::Error &e) {
} catch (dmlc::Error& e) {
What would this catch now? There is already a try block inside.
@piiswrong My intention for the outer block was to catch all other exceptions that are not caught in the inner block. I should change dmlc::Error to std::exception to catch all standard exceptions. But you make a good point about propagating all exceptions, not just dmlc::Error, to the frontend. We can take one of two approaches here: 1. Catch dmlc::Error and terminate the process for everything else. 2. Catch both dmlc::Error and standard exceptions and propagate them to the frontend. This will require code changes to the guards, c_api_error, and potentially the frontend code.
src/storage/cpu_device_storage.h
Outdated
@@ -61,10 +61,10 @@ inline void* CPUDeviceStorage::Alloc(size_t size) {
void* ptr;
#if _MSC_VER
ptr = _aligned_malloc(size, alignment_);
if (ptr == NULL) throw std::bad_alloc();
if (ptr == NULL) LOG(FATAL) << "Malloc failure";
Failed to allocate CPU memory
src/storage/gpu_device_storage.h
Outdated
@@ -62,7 +62,7 @@ inline void* GPUDeviceStorage::Alloc(size_t size) {
#endif // MXNET_USE_NCCL
cudaError_t e = cudaMalloc(&ret, size);
if (e != cudaSuccess && e != cudaErrorCudartUnloading)
throw std::bad_alloc();
LOG(FATAL) << cudaGetLastError();
Does cudaGetLastError return a string or an error code?
Good catch! I have changed it to call cudaGetErrorString.
src/engine/threaded_engine.h
Outdated
@@ -476,6 +507,9 @@ class ThreadedEngine : public Engine {
*/
std::mutex finished_m_;
std::condition_variable finished_cv_;
/*! \brief exception_ptr associated with the engine,
* which is used to throw exception in waitall */
std::exception_ptr global_ex_ptr;
ex_ptr is a bad name. Class members should end with _.
Changed!
Suppose an operator has three outputs x, y and z and it raises an exception. I think an error should only be raised once. After it's raised, it should be cleared from all arrays that point to that error. This can be achieved by setting the object referenced by exception_ptr to an invalid value.
@piiswrong Trying to understand your comment. Let's say we have a code snippet like the one below:
If we clear the exception_ptr corresponding to the var y when x.asnumpy() is executed, y may have some garbage value in it. op2 may end up executing fine, and after the last line
@anirudh2290 Good idea to use an example. From my perspective as a user I generally prefer that things fail as quickly as possible for me so that I don't have to track down root causes. Following this logic, I'd tend to agree with you, but would actually prefer it if this line threw an exception:

    try:
        x, y, z = op()
        x.asnumpy()
    except:
        handle_exc()
    y = op2(y) # Runtime exception
    y.asnumpy()

I understand that this is a lazy operation, but wonder if it's still possible to do some failure validation here? If not, the first blocking call (asnumpy) would also be a fairly intuitive place to throw.
What I meant is, for example:
Currently y.asnumpy() will fail again with the same error as x.asnumpy().
@KellenSunderland As you mentioned, since it is a lazy operation, there is no guarantee that the operation has been executed, only that it has been pushed to the engine. So there is no guarantee that ExecuteOprBlock has been called for the operator. On the other hand, it is guaranteed that all operations which write to a particular variable have executed when the blocking call on that variable is made. Therefore, I have rethrown exceptions in WaitForVar and WaitForAll. I understand that this may not be as intuitive to users as throwing on the
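The lazy-execution argument above can be sketched with a single-threaded toy queue. This is an assumption-laden illustration rather than the MXNet engine: `LazyQueue`, `Push` and `Wait` are hypothetical names, and unlike the real engine the sketch runs every queued operation even after one has failed. It only shows why the push site cannot report the operator's error while the blocking call can.

```cpp
#include <cassert>
#include <deque>
#include <exception>
#include <functional>
#include <stdexcept>
#include <utility>

// Hypothetical toy model of a lazy engine: Push only enqueues work,
// Wait runs it and surfaces the first captured failure.
class LazyQueue {
 public:
  // Enqueue fn without running it, so fn's error cannot be raised here.
  void Push(std::function<void()> fn) {
    assert(fn != nullptr);
    pending_.push_back(std::move(fn));
  }

  // Blocking call: run everything pushed so far, then rethrow the first
  // captured exception, clearing it so that it is reported only once.
  void Wait() {
    while (!pending_.empty()) {
      std::function<void()> fn = std::move(pending_.front());
      pending_.pop_front();
      try {
        fn();
      } catch (...) {
        if (!first_error_) first_error_ = std::current_exception();
      }
    }
    if (first_error_) {
      std::exception_ptr pending = first_error_;
      first_error_ = nullptr;
      std::rethrow_exception(pending);
    }
  }

 private:
  std::deque<std::function<void()>> pending_;
  std::exception_ptr first_error_;
};
```

In this model, pushing a failing operation returns normally, and the error appears only when `Wait()` is called, which mirrors the behavior of asnumpy() as the first blocking call.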
@piiswrong As depicted in the example above, if we decide to invalidate the exception_ptr for y by setting it to nullptr when we WaitToRead x (I am unsure how we would do this), then we won't be propagating exceptions down the chain. Therefore, the last line here will execute just fine instead of throwing an exception, and the user will end up with garbage values for y. I understand your point that if an op has multiple write vars, and we waited for one of the write vars and rethrew the exception, we shouldn't throw it again for the other vars. But if we end up invalidating the exception_ptr, any subsequent operators may or may not fail, and since the exception_ptr is invalidated we wouldn't be rethrowing the exception in any of the following WaitToReads.
After discussion with @piiswrong, we concluded that, within an execution graph, once an exception has been thrown the same exception should not be thrown again. For example:
You can see that op1 throws an exception and may end up writing garbage values to x and y. The line x.asnumpy() throws the exception. After that, the user may handle the exception or keep the garbage values as they are. Any subsequent use of x or op2 should not throw the same exception, since the user is not expecting it and it has already been handled. One challenge during the implementation was that dereferencing exception_ptr in C++ causes undefined behavior, so there is no way to modify the state of the exception object that exception_ptr points to using exception_ptr alone. To work around this limitation, we hold the exception_ptr itself in a shared_ptr object. We decided to remove the global_exception thrown in WaitForAll, since it adds unnecessary complexity and is not used much except during benchmarking.
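A minimal sketch of the shared_ptr workaround described above follows. The names `Var`, `MarkFailed` and `WaitForVar` are hypothetical and the real ThreadedVar carries far more state. The idea is that all outputs of a failed operator share one exception_ptr slot, so clearing the slot at the first wait clears it for every variable that refers to it.

```cpp
#include <cassert>
#include <exception>
#include <initializer_list>
#include <memory>
#include <stdexcept>

// Hypothetical stand-in for an engine variable.
struct Var {
  // Shared slot: every output of one failed operator points at the same exception_ptr.
  std::shared_ptr<std::exception_ptr> var_exception;
};

// Attach one shared exception slot to all outputs of a failed operator.
void MarkFailed(std::initializer_list<Var*> outputs, std::exception_ptr e) {
  auto slot = std::make_shared<std::exception_ptr>(e);
  for (Var* v : outputs) v->var_exception = slot;
}

// Blocking wait: rethrow the pending exception once, then clear the shared slot
// so that waits on sibling variables do not throw the same exception again.
void WaitForVar(Var* v) {
  assert(v != nullptr);
  if (v->var_exception && *v->var_exception) {
    std::exception_ptr pending = *v->var_exception;
    *v->var_exception = nullptr;  // visible through every Var sharing the slot
    std::rethrow_exception(pending);
  }
}
```

With two variables x and y marked by the same failure, `WaitForVar(&x)` throws and a following `WaitForVar(&y)` returns normally, matching the "throw only once in an execution graph" decision.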
@piiswrong: Do you have additional suggestions?
* Add support for threaded engine
* Add support for threaded engine
* Remove on_start_callback for else
* Add support for global_ex_ptr
* Rethrow in waitall only once
* run tests for gpu
* Add comments for exception_ptr
* Fix lint
* Push exc_handling tests
* Add comments for OnStart
* Fixes for exc handling
* Catch std::exception for all other exceptions
* Rollback std::move use
* Fix style
* Fix onstart
* Fix debug_info
* Throw exception only once in an execution graph
* make test naming consistent
* Fix symbolic test
* Remove unused code
Description
Please see: https://cwiki.apache.org/confluence/display/MXNET/Improved+exception+handling+in+MXNet
Implements Exception Handling for Operators.
Fixes #7335 and related issues
Functional Testing
Performance Testing
I ran a small performance test with the resnet50 model on the cifar10 dataset to make sure there is no performance degradation from the additional overhead of the OnStart callback in each ExecuteOprBlock. I don't see any change.
Model: resnet50
Dataset: cifar10
Tested on ec2: p2.8xlarge
Checklist
Essentials
make lint
Changes
Comments
@cjolivier01 @mli @piiswrong @madjam @asmushetzel @eric-haibin-lin @reminisce @rahul003 @KellenSunderland @eftiquar