fix custom exception handling #14575

Closed · wants to merge 7 commits
3 changes: 0 additions & 3 deletions docs/faq/env_var.md
@@ -60,9 +60,6 @@ $env:MXNET_STORAGE_FALLBACK_LOG_VERBOSE=0
 * MXNET_MP_OPENCV_NUM_THREADS
   - Values: Int ```(default=0)```
   - The number of OpenCV execution threads given to multiprocess workers. OpenCV multithreading is disabled if `MXNET_MP_OPENCV_NUM_THREADS` < 1 (default). Enlarging this number may boost the performance of individual workers when executing underlying OpenCV functions, but please consider reducing the overall `num_workers` to avoid thread contention (not available on Windows).
-* MXNET_CUSTOM_OP_NUM_THREADS
-  - Values: Int ```(default=16)```
-  - The maximum number of threads given to custom operators.

## Memory Options

70 changes: 35 additions & 35 deletions src/operator/custom/custom-inl.h
@@ -96,43 +96,44 @@ class CustomOperator {
      bool prev_recording = Imperative::Get()->set_is_recording(recording);
      bool prev_training = Imperative::Get()->set_is_training(training);

-     func();
+     try {
+       func();
+
+       size_t idx = 0;
+       for (const auto& i : arrs) {
+         i.WaitToRead();
+         if (output_tags.count(tags[idx]) > 0) {
+           if (i.storage_type() == kDefaultStorage ||
+               i.storage_type() == kUndefinedStorage)
+             continue;
+           i.WaitToWrite();
+           idx++;
+         }
+       }
+     } catch (dmlc::Error& err) {
+       Imperative::Get()->set_is_training(prev_training);
+       Imperative::Get()->set_is_recording(prev_recording);
+       ctx.async_on_complete(&err);
+       return;
+     }
Member commented:

I think we can solve both 1 and 2 this way: after func is called, do wait_to_read on all elements in arrs, then catch and save the exception. Remove lines 104 and 105. In PushSync, check whether an exception is set and rethrow it; also catch it there, call async_on_complete in PushSync, and return.

Something like the following:

Engine::Get()->PushSync(
    [=](RunContext rctx) {
      try {
        if (exception_) {
          std::rethrow_exception(exception_);
        }
      } catch (dmlc::Error& err) {
        ctx.async_on_complete(&err);
        return;
      }
      ctx.async_on_complete();
    },
    ctx.run_ctx.ctx, vars, vars2, FnProperty::kNormal, 0,
    "CustomOperator");

Thanks to the support added for Horovod in #13932, we may be able to leverage it to call async_on_complete with the error.
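For reference, the error-forwarding pattern that #13932 enables looks roughly like the following. This is a minimal, self-contained sketch, not MXNet's actual engine code: Error and CallbackOnComplete are stand-in types for dmlc::Error and the engine completion callback, which can be invoked with a pointer to the error that occurred.

```cpp
#include <chrono>
#include <functional>
#include <iostream>
#include <stdexcept>
#include <string>
#include <thread>

// Stand-ins for dmlc::Error and the engine's completion callback.
struct Error : std::runtime_error {
  explicit Error(const std::string& msg) : std::runtime_error(msg) {}
};
using CallbackOnComplete = std::function<void(const Error*)>;

// Run user code on a worker thread; instead of letting an exception escape
// the thread (which would terminate the process), forward it to the callback.
void RunAsync(std::function<void()> func, CallbackOnComplete on_complete) {
  std::thread([func, on_complete]() {
    try {
      func();
    } catch (Error& err) {
      on_complete(&err);   // report the failure to whoever is waiting
      return;
    }
    on_complete(nullptr);  // success path
  }).detach();
}

int main() {
  RunAsync([] { throw Error("custom op failed"); },
           [](const Error* err) {
             std::cout << (err ? err->what() : "ok") << std::endl;
           });
  std::this_thread::sleep_for(std::chrono::seconds(1));  // let the worker finish
}
```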

Member Author (@arcadiaphy) commented on Apr 5, 2019:

Adding wait_to_read in the custom op can solve 1 and 2, and the op can then be treated as a normal op without using ExecType::kAsync.

Member commented:

We probably still need PushSync for the sparse NDArray updates.

Member commented:

We still need ExecType::kAsync. The custom operator is still async: when Push is called, it just pushes the work into the custom-op worker queue for later execution. kAsync ensures that threaded_engine_pooled and threaded_engine_per_device treat it as a special case and execute it immediately instead of pushing the work again onto one of the engine worker thread queues, which is unnecessary for a custom op.
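To illustrate the comment above, here is a schematic sketch of how an engine might special-case kAsync ops at dispatch time. This is not the actual threaded-engine code; ExecType, OpBlock, and ThreadedEngineSketch are illustrative names only.

```cpp
#include <functional>
#include <queue>
#include <utility>

enum class ExecType { kSync, kAsync };

struct OpBlock {
  ExecType exec_type;
  std::function<void()> run;
};

class ThreadedEngineSketch {
 public:
  void Push(OpBlock op) {
    if (op.exec_type == ExecType::kAsync) {
      // Special case: run inline. The custom-op worker pool already makes
      // the op asynchronous, so queueing it onto an engine worker thread
      // would only add overhead (and, with blocking waits, deadlock risk).
      op.run();
    } else {
      // Normal ops are queued for the engine worker threads.
      worker_queue_.push(std::move(op));
    }
  }

 private:
  std::queue<OpBlock> worker_queue_;
};
```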

Member Author commented:

After testing, ExecType::kAsync really is needed: adding wait_to_read in an engine worker thread causes a deadlock. But PushSync can be removed, and everything works well.
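The deadlock is easy to reproduce in miniature. Below is a sketch under simplifying assumptions, with a single-worker queue standing in for an engine worker: task A holds the only worker while waiting on task B, but B is queued behind A and can never run. A timeout is used so the demo terminates instead of hanging.

```cpp
#include <chrono>
#include <condition_variable>
#include <functional>
#include <future>
#include <iostream>
#include <mutex>
#include <queue>
#include <thread>
#include <utility>

// A deliberately tiny single-worker "engine": tasks run strictly in order.
std::queue<std::function<void()>> tasks;
std::mutex mtx;
std::condition_variable cv;

void Worker() {
  for (;;) {
    std::function<void()> fn;
    {
      std::unique_lock<std::mutex> lock(mtx);
      cv.wait(lock, [] { return !tasks.empty(); });
      fn = std::move(tasks.front());
      tasks.pop();
    }
    fn();
  }
}

void Push(std::function<void()> fn) {
  {
    std::lock_guard<std::mutex> lock(mtx);
    tasks.push(std::move(fn));
  }
  cv.notify_one();
}

int main() {
  std::thread(Worker).detach();

  std::promise<void> done;
  auto fut = done.get_future();

  // Task A occupies the only worker while waiting for task B's result...
  Push([&] {
    // ...but B sits behind A in the same queue, so the wait can never
    // succeed; with a real blocking wait this thread would hang forever.
    if (fut.wait_for(std::chrono::seconds(2)) == std::future_status::timeout)
      std::cout << "deadlock: worker blocked on work queued behind it\n";
  });
  Push([&] { done.set_value(); });  // task B

  std::this_thread::sleep_for(std::chrono::seconds(3));
}
```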

Member commented:

We probably still need it for sparse: since for sparse we are updating the chunk, it is a write operation, so WaitToRead may not be enough.

Member Author commented:

Yes, I also added WaitToWrite to make sure no exceptions are left out.
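Putting the thread's conclusions together, the final code path (already shown in the diff above) reads as follows. This is just an annotated restatement of the diff, with the reasoning from the discussion folded in as comments; all names come from the diff itself.

```cpp
try {
  func();  // run the user's custom-op callback

  size_t idx = 0;
  for (const auto& i : arrs) {
    // WaitToRead blocks until pending work on the array finishes and
    // rethrows any exception recorded against it.
    i.WaitToRead();
    if (output_tags.count(tags[idx]) > 0) {
      if (i.storage_type() == kDefaultStorage ||
          i.storage_type() == kUndefinedStorage)
        continue;
      // Sparse outputs are updated via SparseUpdateChunk, a write, so
      // WaitToRead alone may miss exceptions; WaitToWrite covers the
      // write dependency as well.
      i.WaitToWrite();
      idx++;
    }
  }
} catch (dmlc::Error& err) {
  // Restore the imperative flags, hand the error to the engine callback,
  // and bail out instead of letting the exception escape the worker thread.
  Imperative::Get()->set_is_training(prev_training);
  Imperative::Get()->set_is_recording(prev_recording);
  ctx.async_on_complete(&err);
  return;
}
```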


      Imperative::Get()->set_is_training(prev_training);
      Imperative::Get()->set_is_recording(prev_recording);

-     std::vector<Engine::VarHandle> vars, vars2;
-     size_t idx = 0;
-     for (const auto& i : arrs) {
-       vars.push_back(i.var());
-       if (output_tags.count(tags[idx]) > 0) {
-         if (i.storage_type() == kDefaultStorage ||
-             i.storage_type() == kUndefinedStorage)
-           continue;
-         vars2.push_back(i.var());
-         idx++;
+     for (size_t i = 0, out_idx = 0; i < arrs.size(); i++) {
+       if (arrs[i].storage_type() == kDefaultStorage ||
+           arrs[i].storage_type() == kUndefinedStorage)
+         continue;
+       if (output_tags.count(tags[i]) > 0) {
+         outputs[out_idx].SparseUpdateChunk(arrs[i]);
+         out_idx++;
        }
      }

-     Engine::Get()->PushSync(
-         [=](RunContext rctx) {
-           for (size_t i = 0, out_idx = 0; i < arrs.size(); i++) {
-             if (arrs[i].storage_type() == kDefaultStorage ||
-                 arrs[i].storage_type() == kUndefinedStorage)
-               continue;
-             if (output_tags.count(tags[i]) > 0) {
-               outputs[out_idx].SparseUpdateChunk(arrs[i]);
-               out_idx++;
-             }
-           }
-           ctx.async_on_complete();
-         },
-         ctx.run_ctx.ctx, vars, vars2, FnProperty::kNormal, 0,
-         "CustomOperator");
+     ctx.async_on_complete();
    });
    // increase num_threads if there is not enough threads to execute custom operator
-   if (q_.size() > num_free_threads)
-     CreateThreads(q_.size() - num_free_threads);
+   if (q_.size() > num_free_threads_)
+     CreateThreads(q_.size() - num_free_threads_);
    cv_.notify_all();
  }
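The q_.size() > num_free_threads_ check above implements a grow-on-demand worker pool. Here is a minimal self-contained sketch of that pattern, with names loosely following the diff (q_, num_free_threads_, CreateThreads) but with a hypothetical max_threads_ cap standing in for the MXNET_CUSTOM_OP_NUM_THREADS limit this PR removes; shutdown handling is omitted for brevity.

```cpp
#include <algorithm>
#include <atomic>
#include <condition_variable>
#include <cstdint>
#include <functional>
#include <mutex>
#include <queue>
#include <thread>
#include <utility>
#include <vector>

// Grow-on-demand worker pool: whenever more tasks are queued than there are
// idle workers, spawn enough new workers to cover the backlog (up to a cap).
class GrowingPool {
 public:
  explicit GrowingPool(int max_threads) : max_threads_(max_threads) {}

  void Push(std::function<void()> fn) {
    std::unique_lock<std::mutex> lock(mtx_);
    q_.push(std::move(fn));
    if (q_.size() > num_free_threads_)
      CreateThreads(static_cast<int>(q_.size() - num_free_threads_));
    cv_.notify_all();
  }

 private:
  void CreateThreads(int num_new) {  // caller holds mtx_
    int target = std::min(max_threads_,
                          static_cast<int>(workers_.size()) + num_new);
    for (int i = static_cast<int>(workers_.size()); i < target; ++i) {
      workers_.emplace_back([this] { ThreadTarget(); });
      ++num_free_threads_;  // a brand-new worker starts out idle
    }
  }

  void ThreadTarget() {
    std::unique_lock<std::mutex> lock(mtx_);
    for (;;) {
      cv_.wait(lock, [this] { return !q_.empty(); });
      while (!q_.empty()) {
        --num_free_threads_;  // this worker is now busy
        auto fn = std::move(q_.front());
        q_.pop();
        lock.unlock();
        fn();                 // run the task without holding the lock
        lock.lock();
        ++num_free_threads_;  // idle again
      }
    }
  }

  int max_threads_;
  std::mutex mtx_;
  std::condition_variable cv_;
  std::atomic<uint32_t> num_free_threads_{0};
  std::queue<std::function<void()>> q_;
  std::vector<std::thread> workers_;
};
```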

@@ -142,7 +143,7 @@ class CustomOperator {
  }

  void Start() {
-   num_free_threads = 0;
+   num_free_threads_ = 0;
    destructing_ = false;
    naive_engine_ = true;
    if (std::string("NaiveEngine") != dmlc::GetEnv("MXNET_ENGINE_TYPE", std::string())) {
@@ -171,21 +172,20 @@ class CustomOperator {
      while (!q_.empty() || !destructing_) {
        cv_.wait(lock, [&] {return !q_.empty() || destructing_;});
        while (!q_.empty()) {
-         --num_free_threads;
+         --num_free_threads_;
          auto fn = q_.front();
          q_.pop();
          lock.unlock();
          fn();
-         ++num_free_threads;
+         ++num_free_threads_;
          lock.lock();
        }
      }
    }
    void SetNumThreads(int num_threads) {
-     num_threads = std::min(dmlc::GetEnv("MXNET_CUSTOM_OP_NUM_THREADS", 16), num_threads);
      for (int i = workers_.size(); i < num_threads; ++i) {
        workers_.emplace_back(std::thread([this]{this->ThreadTarget();}));
-       ++num_free_threads;
+       ++num_free_threads_;
      }
    }
    void CreateThreads(int num_new_threads) {
@@ -196,7 +196,7 @@
    // async worker
    std::condition_variable cv_;
    std::vector<std::thread> workers_;
-   std::atomic<uint32_t> num_free_threads;
+   std::atomic<uint32_t> num_free_threads_;
    std::queue<std::function<void(void)> > q_;
    bool naive_engine_;
    bool destructing_;
57 changes: 35 additions & 22 deletions tests/python/unittest/test_operator.py
@@ -29,6 +29,7 @@
 from mxnet.test_utils import *
 from mxnet.base import py_str, MXNetError, _as_list
 from common import setup_module, with_seed, teardown, assert_raises_cudnn_not_satisfied, assertRaises
+from nose.tools import assert_raises
 import unittest
 import os

@@ -5200,29 +5201,29 @@ def create_operator(self, ctx, shapes, dtypes):

    # test custom operator fork
    # see https://github.com/apache/incubator-mxnet/issues/14396
-    if not sys.platform.startswith('win'):  # no fork in windows
-        class AdditionOP(mx.operator.CustomOp):
-            def __init__(self):
-                super(AdditionOP, self).__init__()
-            def forward(self, is_train, req, in_data, out_data, aux):
-                out_data[0][:] = in_data[0] + in_data[1]
-            def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
-                in_grad[0][:] = out_grad[0]
-                in_grad[1][:] = out_grad[0]
-
-        @mx.operator.register("AdditionOP")
-        class AdditionOPProp(mx.operator.CustomOpProp):
-            def __init__(self):
-                super(AdditionOPProp, self).__init__()
-            def list_arguments(self):
-                return ['a', 'b']
-            def list_outputs(self):
-                return ['output']
-            def infer_shape(self, in_shape):
-                return in_shape, [in_shape[0]]
-            def create_operator(self, ctx, shapes, dtypes):
-                return AdditionOP()
+    class AdditionOP(mx.operator.CustomOp):
+        def __init__(self):
+            super(AdditionOP, self).__init__()
+        def forward(self, is_train, req, in_data, out_data, aux):
+            out_data[0][:] = in_data[0] + in_data[1]
+        def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
+            in_grad[0][:] = out_grad[0]
+            in_grad[1][:] = out_grad[0]
+
+    @mx.operator.register("AdditionOP")
+    class AdditionOPProp(mx.operator.CustomOpProp):
+        def __init__(self):
+            super(AdditionOPProp, self).__init__()
+        def list_arguments(self):
+            return ['a', 'b']
+        def list_outputs(self):
+            return ['output']
+        def infer_shape(self, in_shape):
+            return in_shape, [in_shape[0]]
+        def create_operator(self, ctx, shapes, dtypes):
+            return AdditionOP()
+
+    if not sys.platform.startswith('win'):  # no fork in windows
        def custom_add():
            a = mx.nd.array([1, 2, 3])
            b = mx.nd.array([4, 5, 6])
@@ -5237,6 +5238,18 @@ def custom_add():
        p.join(5)
        assert not p.is_alive(), "deadlock may exist in custom operator"

+        # test except handling
+        # see https://github.com/apache/incubator-mxnet/pull/14575
+        def custom_add_exc():
Contributor commented:

Can we add a comment that an exception is expected due to shapes I assume?

Member Author commented:

Sure, I'll add it.

+            a = mx.nd.array([1, 2, 3])
+            b = mx.nd.array([4, 5])
+            # trigger exception by providing unmatched operand shapes
+            c = mx.nd.Custom(a, b, op_type='AdditionOP')
+            c.wait_to_read()
+
+        assert_raises(MXNetError, custom_add_exc)


@with_seed()
def test_psroipooling():
    for num_rois in [1, 2]: