Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
fix custom operation in fork (#14451)
Browse files Browse the repository at this point in the history
* fix custom operation in fork

* add test

* fix custom stop

* swap order

* add docs

* update doc
  • Loading branch information
arcadiaphy authored and nswamy committed Apr 5, 2019
1 parent 708931a commit 4b256aa
Show file tree
Hide file tree
Showing 6 changed files with 83 additions and 14 deletions.
20 changes: 20 additions & 0 deletions docs/tutorials/gluon/customop.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ Custom operator in python is easy to develop and good for prototyping, but may h
import numpy as np
import mxnet as mx
from mxnet import gluon, autograd
import os
```

## Parameter-less operators
Expand Down Expand Up @@ -214,5 +215,24 @@ y = dense(x)
print(y)
```

## Using custom operators with fork
In Linux systems, the default method in multiprocessing to create process is by using fork. If there are unfinished async custom operations when forking, the program will be blocked because of python GIL. Always use sync calls like `wait_to_read` or `waitall` before calling fork.

```
x = mx.nd.array([0, 1, 2, 3])
y = mx.nd.Custom(x, op_type='sigmoid')
# unfinished async sigmoid operation will cause blocking
os.fork()
```

Correctly handling this will make mxnet depend upon libpython, so the workaround now is to ensure that all custom operations are executed before forking process.

```
x = mx.nd.array([0, 1, 2, 3])
y = mx.nd.Custom(x, op_type='sigmoid')
# force execution by reading y
print(y.asnumpy())
os.fork()
```

<!-- INSERT SOURCE DOWNLOAD BUTTONS -->
1 change: 1 addition & 0 deletions src/c_api/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ int MXRandomSeedContext(int seed, int dev_type, int dev_id) {

int MXNotifyShutdown() {
API_BEGIN();
mxnet::op::custom::CustomOperator::Get()->Stop();
Engine::Get()->NotifyShutdown();
API_END();
}
Expand Down
5 changes: 5 additions & 0 deletions src/initialize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include <dmlc/logging.h>
#include <mxnet/engine.h>
#include "./engine/openmp.h"
#include "./operator/custom/custom-inl.h"
#if MXNET_USE_OPENCV
#include <opencv2/opencv.hpp>
#endif // MXNET_USE_OPENCV
Expand Down Expand Up @@ -53,12 +54,15 @@ class LibraryInitializer {

// disable openmp for multithreaded workers
#ifndef _WIN32
using op::custom::CustomOperator;
pthread_atfork(
[]() {
CustomOperator::Get()->Stop();
Engine::Get()->Stop();
},
[]() {
Engine::Get()->Start();
CustomOperator::Get()->Start();
},
[]() {
// Conservative thread management for multiprocess workers
Expand All @@ -71,6 +75,7 @@ class LibraryInitializer {
#endif // MXNET_USE_OPENCV
engine::OpenMP::Get()->set_enabled(false);
Engine::Get()->Start();
CustomOperator::Get()->Start();
});
#endif
}
Expand Down
27 changes: 18 additions & 9 deletions src/operator/custom/custom-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,21 @@ class CustomOperator {
cv_.notify_all();
}

~CustomOperator() {
static CustomOperator* Get() {
static CustomOperator inst;
return &inst;
}

void Start() {
num_free_threads = 0;
destructing_ = false;
naive_engine_ = true;
if (std::string("NaiveEngine") != dmlc::GetEnv("MXNET_ENGINE_TYPE", std::string())) {
naive_engine_ = false;
}
}

void Stop() {
if (naive_engine_) return;
{
std::unique_lock<std::mutex> lock(mutex_);
Expand All @@ -145,17 +159,12 @@ class CustomOperator {
}
for (auto &worker : workers_)
worker.join();
workers_.clear();
}

static CustomOperator* Get();

private:
CustomOperator() : num_free_threads(0) {
destructing_ = false;
naive_engine_ = true;
if (std::string("NaiveEngine") != dmlc::GetEnv("MXNET_ENGINE_TYPE", std::string())) {
naive_engine_ = false;
}
CustomOperator() {
this->Start();
}
void ThreadTarget() {
std::unique_lock<std::mutex> lock(mutex_);
Expand Down
5 changes: 0 additions & 5 deletions src/operator/custom/custom.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,6 @@ namespace mxnet {
namespace op {
namespace custom {

CustomOperator* CustomOperator::Get() {
static CustomOperator inst;
return &inst;
}

struct CustomParam {
std::string op_type;
size_t num_args, num_outs, num_auxs;
Expand Down
39 changes: 39 additions & 0 deletions tests/python/unittest/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5198,6 +5198,45 @@ def create_operator(self, ctx, shapes, dtypes):
x = mx.nd.Custom(length=10, depth=10, op_type="no_input_op")
assert_almost_equal(x.asnumpy(), np.ones(shape=(10, 10), dtype=np.float32))

# test custom operator fork
# see https://github.com/apache/incubator-mxnet/issues/14396
if not sys.platform.startswith('win'): # no fork in windows
class AdditionOP(mx.operator.CustomOp):
def __init__(self):
super(AdditionOP, self).__init__()
def forward(self, is_train, req, in_data, out_data, aux):
out_data[0][:] = in_data[0] + in_data[1]
def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
in_grad[0][:] = out_grad[0]
in_grad[1][:] = out_grad[0]

@mx.operator.register("AdditionOP")
class AdditionOPProp(mx.operator.CustomOpProp):
def __init__(self):
super(AdditionOPProp, self).__init__()
def list_arguments(self):
return ['a', 'b']
def list_outputs(self):
return ['output']
def infer_shape(self, in_shape):
return in_shape, [in_shape[0]]
def create_operator(self, ctx, shapes, dtypes):
return AdditionOP()

def custom_add():
a = mx.nd.array([1, 2, 3])
b = mx.nd.array([4, 5, 6])
c = mx.nd.Custom(a, b, op_type='AdditionOP')
assert_almost_equal((a + b).asnumpy(), c.asnumpy())

custom_add()
from multiprocessing import Process
p = Process(target=custom_add)
p.daemon = True
p.start()
p.join(5)
assert not p.is_alive(), "deadlock may exist in custom operator"

@with_seed()
def test_psroipooling():
for num_rois in [1, 2]:
Expand Down

0 comments on commit 4b256aa

Please sign in to comment.