-
Notifications
You must be signed in to change notification settings - Fork 6.8k
fix custom exception handling #14575
Changes from 2 commits
da74a63
2b1da9e
e43f13d
5bce27e
23aee98
963be55
931aca9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -48,6 +48,11 @@ namespace mxnet { | |
namespace op { | ||
namespace custom { | ||
|
||
struct CustomTask { | ||
std::function<void(void)> fn; | ||
mxnet::engine::CallbackOnComplete on_complete; | ||
}; | ||
|
||
class CustomOperator { | ||
public: | ||
void Register(const std::string &op_type, CustomOpPropCreator creator) { | ||
|
@@ -92,7 +97,7 @@ class CustomOperator { | |
return; | ||
} | ||
std::unique_lock<std::mutex> lock(mutex_); | ||
q_.push([=]() mutable { | ||
q_.push({[=]() mutable { | ||
bool prev_recording = Imperative::Get()->set_is_recording(recording); | ||
bool prev_training = Imperative::Get()->set_is_training(training); | ||
|
||
|
@@ -129,7 +134,7 @@ class CustomOperator { | |
}, | ||
ctx.run_ctx.ctx, vars, vars2, FnProperty::kNormal, 0, | ||
"CustomOperator"); | ||
}); | ||
}, ctx.async_on_complete}); | ||
// increase num_threads if there is not enough threads to execute custom operator | ||
if (q_.size() > num_free_threads) | ||
CreateThreads(q_.size() - num_free_threads); | ||
|
@@ -145,6 +150,7 @@ class CustomOperator { | |
num_free_threads = 0; | ||
destructing_ = false; | ||
naive_engine_ = true; | ||
exception_ = nullptr; | ||
if (std::string("NaiveEngine") != dmlc::GetEnv("MXNET_ENGINE_TYPE", std::string())) { | ||
naive_engine_ = false; | ||
} | ||
|
@@ -162,6 +168,14 @@ class CustomOperator { | |
workers_.clear(); | ||
} | ||
|
||
inline void ThrowException() { | ||
if (exception_ && *exception_) { | ||
std::exception_ptr tmp = *exception_; | ||
exception_ = nullptr; | ||
std::rethrow_exception(tmp); | ||
} | ||
} | ||
|
||
private: | ||
CustomOperator() { | ||
this->Start(); | ||
|
@@ -172,10 +186,16 @@ class CustomOperator { | |
cv_.wait(lock, [&] {return !q_.empty() || destructing_;}); | ||
while (!q_.empty()) { | ||
--num_free_threads; | ||
auto fn = q_.front(); | ||
auto task = q_.front(); | ||
q_.pop(); | ||
lock.unlock(); | ||
fn(); | ||
try { | ||
task.fn(); | ||
} catch (dmlc::Error& e) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What about other types of exceptions? Could we add a comment in the code clarifying this? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What other types? I think the only valid exception is dmlc::Error, other exception means wrong code. |
||
exception_ = | ||
std::make_shared<std::exception_ptr>(std::current_exception()); | ||
task.on_complete(); | ||
} | ||
++num_free_threads; | ||
lock.lock(); | ||
} | ||
|
@@ -197,7 +217,8 @@ class CustomOperator { | |
std::condition_variable cv_; | ||
std::vector<std::thread> workers_; | ||
std::atomic<uint32_t> num_free_threads; | ||
std::queue<std::function<void(void)> > q_; | ||
std::queue<CustomTask> q_; | ||
std::shared_ptr<std::exception_ptr> exception_; | ||
bool naive_engine_; | ||
bool destructing_; | ||
}; | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -29,6 +29,7 @@ | |
from mxnet.test_utils import * | ||
from mxnet.base import py_str, MXNetError, _as_list | ||
from common import setup_module, with_seed, teardown, assert_raises_cudnn_not_satisfied, assertRaises | ||
from nose.tools import assert_raises | ||
import unittest | ||
import os | ||
|
||
|
@@ -5200,29 +5201,29 @@ def create_operator(self, ctx, shapes, dtypes): | |
|
||
# test custom operator fork | ||
# see https://github.com/apache/incubator-mxnet/issues/14396 | ||
if not sys.platform.startswith('win'): # no fork in windows | ||
class AdditionOP(mx.operator.CustomOp): | ||
def __init__(self): | ||
super(AdditionOP, self).__init__() | ||
def forward(self, is_train, req, in_data, out_data, aux): | ||
out_data[0][:] = in_data[0] + in_data[1] | ||
def backward(self, req, out_grad, in_data, out_data, in_grad, aux): | ||
in_grad[0][:] = out_grad[0] | ||
in_grad[1][:] = out_grad[0] | ||
|
||
@mx.operator.register("AdditionOP") | ||
class AdditionOPProp(mx.operator.CustomOpProp): | ||
def __init__(self): | ||
super(AdditionOPProp, self).__init__() | ||
def list_arguments(self): | ||
return ['a', 'b'] | ||
def list_outputs(self): | ||
return ['output'] | ||
def infer_shape(self, in_shape): | ||
return in_shape, [in_shape[0]] | ||
def create_operator(self, ctx, shapes, dtypes): | ||
return AdditionOP() | ||
class AdditionOP(mx.operator.CustomOp): | ||
def __init__(self): | ||
super(AdditionOP, self).__init__() | ||
def forward(self, is_train, req, in_data, out_data, aux): | ||
out_data[0][:] = in_data[0] + in_data[1] | ||
def backward(self, req, out_grad, in_data, out_data, in_grad, aux): | ||
in_grad[0][:] = out_grad[0] | ||
in_grad[1][:] = out_grad[0] | ||
|
||
@mx.operator.register("AdditionOP") | ||
class AdditionOPProp(mx.operator.CustomOpProp): | ||
def __init__(self): | ||
super(AdditionOPProp, self).__init__() | ||
def list_arguments(self): | ||
return ['a', 'b'] | ||
def list_outputs(self): | ||
return ['output'] | ||
def infer_shape(self, in_shape): | ||
return in_shape, [in_shape[0]] | ||
def create_operator(self, ctx, shapes, dtypes): | ||
return AdditionOP() | ||
|
||
if not sys.platform.startswith('win'): # no fork in windows | ||
def custom_add(): | ||
a = mx.nd.array([1, 2, 3]) | ||
b = mx.nd.array([4, 5, 6]) | ||
|
@@ -5237,6 +5238,17 @@ def custom_add(): | |
p.join(5) | ||
assert not p.is_alive(), "deadlock may exist in custom operator" | ||
|
||
# test except handling | ||
# see https://github.com/apache/incubator-mxnet/pull/14575 | ||
def custom_add_exc(): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we add a comment that an exception is expected due to shapes I assume? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sure, I'll add it. |
||
a = mx.nd.array([1, 2, 3]) | ||
b = mx.nd.array([4, 5]) | ||
c = mx.nd.Custom(a, b, op_type='AdditionOP') | ||
c.wait_to_read() | ||
|
||
assert_raises(MXNetError, custom_add_exc) | ||
|
||
|
||
@with_seed() | ||
def test_psroipooling(): | ||
for num_rois in [1, 2]: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Lets remove this. ThreadedEngine should not depend on custom operator code.