Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

fix custom exception handling #14575

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/engine/threaded_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include <utility>
#include "./threaded_engine.h"
#include "../common/cuda_utils.h"
#include "../operator/custom/custom-inl.h"

namespace mxnet {
namespace engine {
Expand Down Expand Up @@ -373,10 +374,12 @@ void ThreadedEngine::DeleteVariable(SyncFn delete_fn,
}

void ThreadedEngine::WaitForVar(VarHandle var) {
using mxnet::op::custom::CustomOperator;
BulkFlush();
ThreadedVar* threaded_var = ThreadedVar::CastFromBase(var);
if (threaded_var->ready_to_read()) {
ThrowException(threaded_var);
CustomOperator::Get()->ThrowException();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lets remove this. ThreadedEngine should not depend on custom operator code.

return;
}
if (engine_info_) {
Expand Down Expand Up @@ -407,6 +410,7 @@ void ThreadedEngine::WaitForVar(VarHandle var) {
}

ThrowException(threaded_var);
CustomOperator::Get()->ThrowException();
}

void ThreadedEngine::WaitForAll() {
Expand Down
31 changes: 26 additions & 5 deletions src/operator/custom/custom-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,11 @@ namespace mxnet {
namespace op {
namespace custom {

struct CustomTask {
std::function<void(void)> fn;
mxnet::engine::CallbackOnComplete on_complete;
};

class CustomOperator {
public:
void Register(const std::string &op_type, CustomOpPropCreator creator) {
Expand Down Expand Up @@ -92,7 +97,7 @@ class CustomOperator {
return;
}
std::unique_lock<std::mutex> lock(mutex_);
q_.push([=]() mutable {
q_.push({[=]() mutable {
bool prev_recording = Imperative::Get()->set_is_recording(recording);
bool prev_training = Imperative::Get()->set_is_training(training);

Expand Down Expand Up @@ -129,7 +134,7 @@ class CustomOperator {
},
ctx.run_ctx.ctx, vars, vars2, FnProperty::kNormal, 0,
"CustomOperator");
});
}, ctx.async_on_complete});
// increase num_threads if there is not enough threads to execute custom operator
if (q_.size() > num_free_threads)
CreateThreads(q_.size() - num_free_threads);
Expand All @@ -145,6 +150,7 @@ class CustomOperator {
num_free_threads = 0;
destructing_ = false;
naive_engine_ = true;
exception_ = nullptr;
if (std::string("NaiveEngine") != dmlc::GetEnv("MXNET_ENGINE_TYPE", std::string())) {
naive_engine_ = false;
}
Expand All @@ -162,6 +168,14 @@ class CustomOperator {
workers_.clear();
}

inline void ThrowException() {
if (exception_ && *exception_) {
std::exception_ptr tmp = *exception_;
exception_ = nullptr;
std::rethrow_exception(tmp);
}
}

private:
CustomOperator() {
this->Start();
Expand All @@ -172,10 +186,16 @@ class CustomOperator {
cv_.wait(lock, [&] {return !q_.empty() || destructing_;});
while (!q_.empty()) {
--num_free_threads;
auto fn = q_.front();
auto task = q_.front();
q_.pop();
lock.unlock();
fn();
try {
task.fn();
} catch (dmlc::Error& e) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about other types of exceptions? Could we add a comment in the code clarifying this?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What other types? I think the only valid exception is dmlc::Error, other exception means wrong code.

exception_ =
std::make_shared<std::exception_ptr>(std::current_exception());
task.on_complete();
}
++num_free_threads;
lock.lock();
}
Expand All @@ -197,7 +217,8 @@ class CustomOperator {
std::condition_variable cv_;
std::vector<std::thread> workers_;
std::atomic<uint32_t> num_free_threads;
std::queue<std::function<void(void)> > q_;
std::queue<CustomTask> q_;
std::shared_ptr<std::exception_ptr> exception_;
bool naive_engine_;
bool destructing_;
};
Expand Down
56 changes: 34 additions & 22 deletions tests/python/unittest/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from mxnet.test_utils import *
from mxnet.base import py_str, MXNetError, _as_list
from common import setup_module, with_seed, teardown, assert_raises_cudnn_not_satisfied, assertRaises
from nose.tools import assert_raises
import unittest
import os

Expand Down Expand Up @@ -5200,29 +5201,29 @@ def create_operator(self, ctx, shapes, dtypes):

# test custom operator fork
# see https://github.com/apache/incubator-mxnet/issues/14396
if not sys.platform.startswith('win'): # no fork in windows
class AdditionOP(mx.operator.CustomOp):
def __init__(self):
super(AdditionOP, self).__init__()
def forward(self, is_train, req, in_data, out_data, aux):
out_data[0][:] = in_data[0] + in_data[1]
def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
in_grad[0][:] = out_grad[0]
in_grad[1][:] = out_grad[0]

@mx.operator.register("AdditionOP")
class AdditionOPProp(mx.operator.CustomOpProp):
def __init__(self):
super(AdditionOPProp, self).__init__()
def list_arguments(self):
return ['a', 'b']
def list_outputs(self):
return ['output']
def infer_shape(self, in_shape):
return in_shape, [in_shape[0]]
def create_operator(self, ctx, shapes, dtypes):
return AdditionOP()
class AdditionOP(mx.operator.CustomOp):
def __init__(self):
super(AdditionOP, self).__init__()
def forward(self, is_train, req, in_data, out_data, aux):
out_data[0][:] = in_data[0] + in_data[1]
def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
in_grad[0][:] = out_grad[0]
in_grad[1][:] = out_grad[0]

@mx.operator.register("AdditionOP")
class AdditionOPProp(mx.operator.CustomOpProp):
def __init__(self):
super(AdditionOPProp, self).__init__()
def list_arguments(self):
return ['a', 'b']
def list_outputs(self):
return ['output']
def infer_shape(self, in_shape):
return in_shape, [in_shape[0]]
def create_operator(self, ctx, shapes, dtypes):
return AdditionOP()

if not sys.platform.startswith('win'): # no fork in windows
def custom_add():
a = mx.nd.array([1, 2, 3])
b = mx.nd.array([4, 5, 6])
Expand All @@ -5237,6 +5238,17 @@ def custom_add():
p.join(5)
assert not p.is_alive(), "deadlock may exist in custom operator"

# test except handling
# see https://github.com/apache/incubator-mxnet/pull/14575
def custom_add_exc():
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add a comment that an exception is expected due to shapes I assume?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, I'll add it.

a = mx.nd.array([1, 2, 3])
b = mx.nd.array([4, 5])
c = mx.nd.Custom(a, b, op_type='AdditionOP')
c.wait_to_read()

assert_raises(MXNetError, custom_add_exc)


@with_seed()
def test_psroipooling():
for num_rois in [1, 2]:
Expand Down