Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
add test
Browse files Browse the repository at this point in the history
  • Loading branch information
arcadiaphy committed Mar 17, 2019
1 parent bbc5414 commit b973f8d
Showing 1 changed file with 43 additions and 0 deletions.
43 changes: 43 additions & 0 deletions tests/python/unittest/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5198,6 +5198,49 @@ def create_operator(self, ctx, shapes, dtypes):
x = mx.nd.Custom(length=10, depth=10, op_type="no_input_op")
assert_almost_equal(x.asnumpy(), np.ones(shape=(10, 10), dtype=np.float32))

# test custom operator fork
# see https://github.com/apache/incubator-mxnet/issues/14396
class AdditionOP(mx.operator.CustomOp):
def __init__(self):
super(AdditionOP, self).__init__()
def forward(self, is_train, req, in_data, out_data, aux):
out_data[0][:] = in_data[0] + in_data[1]
def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
in_grad[0][:] = out_grad[0]
in_grad[1][:] = out_grad[0]

@mx.operator.register("AdditionOP")
class AdditionOPProp(mx.operator.CustomOpProp):
def __init__(self):
super(AdditionOPProp, self).__init__()
def list_arguments(self):
return ['a', 'b']
def list_outputs(self):
return ['output']
def infer_shape(self, in_shape):
return in_shape, [in_shape[0]]
def create_operator(self, ctx, shapes, dtypes):
return AdditionOP()

def custom_add():
a = mx.nd.array([1, 2, 3])
b = mx.nd.array([4, 5, 6])
a.attach_grad()
b.attach_grad()

with mx.autograd.record():
c = mx.nd.Custom(a, b, op_type='AdditionOP')

dc = mx.nd.array([7, 8, 9])
c.backward(dc)

custom_add()
from multiprocessing import Process
p = Process(target=custom_add)
p.start()
p.join(5)
assert not p.is_alive(), "deadlock may exist in custom operator"

@with_seed()
def test_psroipooling():
for num_rois in [1, 2]:
Expand Down

0 comments on commit b973f8d

Please sign in to comment.