Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Properly handling custom op exception by modify engine #14693

Merged
merged 6 commits into from
Apr 16, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions docs/faq/env_var.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,6 @@ $env:MXNET_STORAGE_FALLBACK_LOG_VERBOSE=0
* MXNET_MP_OPENCV_NUM_THREADS
- Values: Int ```(default=0)```
- The number of OpenCV execution threads given to multiprocess workers. OpenCV multithreading is disabled if `MXNET_MP_OPENCV_NUM_THREADS` < 1 (default). Enlarge this number may boost the performance of individual workers when executing underlying OpenCV functions but please consider reducing the overall `num_workers` to avoid thread contention (not available on Windows).
* MXNET_CUSTOM_OP_NUM_THREADS
- Values: Int ```(default=16)```
- The maximum number of threads given to custom operators.

## Memory Options

Expand Down
6 changes: 5 additions & 1 deletion include/mxnet/engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,9 @@ enum class FnProperty {
/*! \brief Delete variable call */
kDeleteVar,
/*! \brief Prioritized sync operation on GPU */
kGPUPrioritized
kGPUPrioritized,
/*! \brief Operation not to be skipped even with associated exception */
kNoSkip
}; // enum class FnProperty

/*!
Expand Down Expand Up @@ -230,6 +232,8 @@ class MXNET_API Engine {
* \brief Wait until all the activity of engine finishes.
*/
virtual void WaitForAll() = 0;
/*!\brief Throw if threre are associated exception with var */
virtual void Throw(VarHandle var) = 0;
/*!\brief virtual destructor */
virtual ~Engine() noexcept(false) {}
/*!
Expand Down
3 changes: 3 additions & 0 deletions src/engine/naive_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,9 @@ class NaiveEngine final : public Engine {
void WaitForAll() override {
}

void Throw(VarHandle var) override {
}

void NotifyShutdown() override {
shutdown_phase_.store(true);
}
Expand Down
5 changes: 5 additions & 0 deletions src/engine/threaded_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -498,6 +498,11 @@ inline void ThreadedEngine::ThrowException(ThreadedVar* threaded_var) {
return;
}

void ThreadedEngine::Throw(VarHandle var) {
ThreadedVar *threaded_var = ThreadedVar::CastFromBase(var);
ThrowException(threaded_var);
}
szha marked this conversation as resolved.
Show resolved Hide resolved

void ThreadedEngine::OnCompleteStatic(Engine *engine, void *opr_block_,
const dmlc::Error* error) {
OprBlock *opr_block = static_cast<OprBlock*>(opr_block_);
Expand Down
5 changes: 3 additions & 2 deletions src/engine/threaded_engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,7 @@ class ThreadedEngine : public Engine {
void DeleteVariable(SyncFn delete_fn, Context exec_ctx, VarHandle var) override;
void WaitForVar(VarHandle var) override;
void WaitForAll() override;
void Throw(VarHandle var) override;
void NotifyShutdown() override {
shutdown_phase_.store(true);
}
Expand Down Expand Up @@ -374,8 +375,8 @@ class ThreadedEngine : public Engine {
LOG(INFO) << "ExecuteOprFn ";
}
try {
if (!(threaded_opr->opr_exception && *threaded_opr->opr_exception) ||
threaded_opr->wait) {
if ((!(threaded_opr->opr_exception && *threaded_opr->opr_exception) ||
threaded_opr->prop == FnProperty::kNoSkip) || threaded_opr->wait) {
threaded_opr->fn(run_ctx, callback);
} else {
callback();
Expand Down
45 changes: 35 additions & 10 deletions src/operator/custom/custom-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,12 @@ class CustomOperator {
bool prev_recording = Imperative::Get()->set_is_recording(recording);
bool prev_training = Imperative::Get()->set_is_training(training);

func();
try {
func();
} catch (dmlc::Error& e) {
exception_ =
std::make_shared<std::exception_ptr>(std::current_exception());
}

Imperative::Get()->set_is_training(prev_training);
Imperative::Get()->set_is_recording(prev_recording);
Expand All @@ -116,6 +121,16 @@ class CustomOperator {

Engine::Get()->PushSync(
[=](RunContext rctx) {
try {
Throw();
for (const auto& i : arrs) {
Engine::Get()->Throw(i.var());
}
szha marked this conversation as resolved.
Show resolved Hide resolved
} catch(dmlc::Error& err) {
ctx.async_on_complete(&err);
return;
}

for (size_t i = 0, out_idx = 0; i < arrs.size(); i++) {
if (arrs[i].storage_type() == kDefaultStorage ||
arrs[i].storage_type() == kUndefinedStorage)
Expand All @@ -125,14 +140,15 @@ class CustomOperator {
out_idx++;
}
}

ctx.async_on_complete();
},
ctx.run_ctx.ctx, vars, vars2, FnProperty::kNormal, 0,
ctx.run_ctx.ctx, vars, vars2, FnProperty::kNoSkip, 0,
"CustomOperator");
});
// increase num_threads if there is not enough threads to execute custom operator
if (q_.size() > num_free_threads)
CreateThreads(q_.size() - num_free_threads);
if (q_.size() > num_free_threads_)
CreateThreads(q_.size() - num_free_threads_);
cv_.notify_all();
}

Expand All @@ -142,9 +158,10 @@ class CustomOperator {
}

void Start() {
num_free_threads = 0;
num_free_threads_ = 0;
destructing_ = false;
naive_engine_ = true;
exception_ = nullptr;
if (std::string("NaiveEngine") != dmlc::GetEnv("MXNET_ENGINE_TYPE", std::string())) {
naive_engine_ = false;
}
Expand All @@ -162,6 +179,14 @@ class CustomOperator {
workers_.clear();
}

inline void Throw() {
if (exception_ && *exception_) {
std::exception_ptr tmp = *exception_;
exception_ = nullptr;
std::rethrow_exception(tmp);
}
}

private:
CustomOperator() {
this->Start();
Expand All @@ -171,21 +196,20 @@ class CustomOperator {
while (!q_.empty() || !destructing_) {
cv_.wait(lock, [&] {return !q_.empty() || destructing_;});
while (!q_.empty()) {
--num_free_threads;
--num_free_threads_;
auto fn = q_.front();
q_.pop();
lock.unlock();
fn();
++num_free_threads;
++num_free_threads_;
lock.lock();
}
}
}
void SetNumThreads(int num_threads) {
num_threads = std::min(dmlc::GetEnv("MXNET_CUSTOM_OP_NUM_THREADS", 16), num_threads);
for (int i = workers_.size(); i < num_threads; ++i) {
workers_.emplace_back(std::thread([this]{this->ThreadTarget();}));
++num_free_threads;
++num_free_threads_;
}
}
void CreateThreads(int num_new_threads) {
Expand All @@ -196,8 +220,9 @@ class CustomOperator {
// async worker
std::condition_variable cv_;
std::vector<std::thread> workers_;
std::atomic<uint32_t> num_free_threads;
std::atomic<uint32_t> num_free_threads_;
std::queue<std::function<void(void)> > q_;
std::shared_ptr<std::exception_ptr> exception_;
bool naive_engine_;
bool destructing_;
};
Expand Down
129 changes: 107 additions & 22 deletions tests/python/unittest/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
from mxnet.test_utils import *
from mxnet.base import py_str, MXNetError, _as_list
from common import setup_module, with_seed, teardown, assert_raises_cudnn_not_satisfied, assertRaises
from common import run_in_spawned_process
from nose.tools import assert_raises
import unittest
import os

Expand Down Expand Up @@ -5355,29 +5357,29 @@ def create_operator(self, ctx, shapes, dtypes):

# test custom operator fork
# see https://github.com/apache/incubator-mxnet/issues/14396
if not sys.platform.startswith('win'): # no fork in windows
class AdditionOP(mx.operator.CustomOp):
def __init__(self):
super(AdditionOP, self).__init__()
def forward(self, is_train, req, in_data, out_data, aux):
out_data[0][:] = in_data[0] + in_data[1]
def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
in_grad[0][:] = out_grad[0]
in_grad[1][:] = out_grad[0]

@mx.operator.register("AdditionOP")
class AdditionOPProp(mx.operator.CustomOpProp):
def __init__(self):
super(AdditionOPProp, self).__init__()
def list_arguments(self):
return ['a', 'b']
def list_outputs(self):
return ['output']
def infer_shape(self, in_shape):
return in_shape, [in_shape[0]]
def create_operator(self, ctx, shapes, dtypes):
return AdditionOP()
class AdditionOP(mx.operator.CustomOp):
def __init__(self):
super(AdditionOP, self).__init__()
def forward(self, is_train, req, in_data, out_data, aux):
out_data[0][:] = in_data[0] + in_data[1]
szha marked this conversation as resolved.
Show resolved Hide resolved
def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
in_grad[0][:] = out_grad[0]
in_grad[1][:] = out_grad[0]

@mx.operator.register("AdditionOP")
class AdditionOPProp(mx.operator.CustomOpProp):
def __init__(self):
super(AdditionOPProp, self).__init__()
def list_arguments(self):
return ['a', 'b']
def list_outputs(self):
return ['output']
def infer_shape(self, in_shape):
return in_shape, [in_shape[0]]
def create_operator(self, ctx, shapes, dtypes):
return AdditionOP()

if not sys.platform.startswith('win'): # no fork in windows
def custom_add():
a = mx.nd.array([1, 2, 3])
b = mx.nd.array([4, 5, 6])
Expand All @@ -5392,6 +5394,89 @@ def custom_add():
p.join(5)
assert not p.is_alive(), "deadlock may exist in custom operator"


def _build_dot_custom(fun_forward, name):
class Dot(mx.operator.CustomOp):
def __init__(self):
super(Dot, self).__init__()
def forward(self, is_train, req, in_data, out_data, aux):
fun_forward(in_data, out_data)
def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
pass

@mx.operator.register(name)
class DotProp(mx.operator.CustomOpProp):
def __init__(self):
super(DotProp, self).__init__()
def list_arguments(self):
return ['a', 'b']
def list_outputs(self):
return ['output']
def infer_shape(self, in_shape):
return in_shape, [(in_shape[0][0], in_shape[1][1])]
def create_operator(self, ctx, shapes, dtypes):
return Dot()

def _custom_exc3(seed):
def custom_exc3():
def f(in_data, out_data):
out_data[0][:] = mx.nd.dot(in_data[0], in_data[1])
out_data[0].wait_to_read()
_build_dot_custom(f, 'Dot3')
n = int(1e8)
a = mx.nd.zeros((n, 1))
b = mx.nd.zeros((1, n))
# trigger OOM
c = mx.nd.Custom(a, b, op_type='Dot3')
c.wait_to_read()
assert_raises(MXNetError, custom_exc3)

def _custom_exc4(seed):
def custom_exc4():
def f(in_data, out_data):
out_data[0][:] = mx.nd.dot(in_data[0], in_data[1])
_build_dot_custom(f, 'Dot4')
n = int(1e8)
a = mx.nd.zeros((n, 1))
b = mx.nd.zeros((1, n))
# trigger OOM
c = mx.nd.Custom(a, b, op_type='Dot4')
c.wait_to_read()
assert_raises(MXNetError, custom_exc4)

@with_seed()
def test_custom_op_exc():
# test except handling
# see https://github.com/apache/incubator-mxnet/pull/14693
# 1. error in python code
def custom_exc1():
def f(in_data, out_data):
assert False
out_data[0][:] = mx.nd.dot(in_data[0], in_data[1])
_build_dot_custom(f, 'Dot1')
a = mx.nd.zeros((4, 1))
b = mx.nd.zeros((1, 4))
c = mx.nd.Custom(a, b, op_type='Dot1')
c.wait_to_read()
assert_raises(MXNetError, custom_exc1)

# 2. error in pushing operator to engine
def custom_exc2():
def f(in_data, out_data):
out_data[0][:] = mx.nd.dot(in_data[0], in_data[1])
_build_dot_custom(f, 'Dot2')
a = mx.nd.zeros((4, 2))
b = mx.nd.zeros((1, 4))
# trigger error by invalid input shapes of operands
c = mx.nd.Custom(a, b, op_type='Dot2')
c.wait_to_read()
assert_raises(MXNetError, custom_exc2)

# 3. error in real execution
run_in_spawned_process(_custom_exc3, {})
run_in_spawned_process(_custom_exc4, {})


@with_seed()
def test_psroipooling():
for num_rois in [1, 2]:
Expand Down