Fix to enable model parallelism for non-fp32 data (#16683)
asmushetzel authored and szha committed Dec 4, 2019
1 parent 78a2523 commit 02b4d2b
Showing 2 changed files with 32 additions and 16 deletions.
10 changes: 10 additions & 0 deletions src/operator/cross_device_copy.cc
@@ -60,6 +60,16 @@ class CrossDeviceCopyProp : public OperatorProperty {
     return true;
   }
 
+  bool InferType(std::vector<int> *in_type,
+                 std::vector<int> *out_type,
+                 std::vector<int> *aux_type) const {
+    CHECK_EQ(in_type->size(), 1) << "Input:[data]";
+    if (in_type->at(0) == -1) return false;
+    out_type->clear();
+    out_type->push_back(in_type->at(0));
+    return true;
+  }
+
   OperatorProperty* Copy() const override {
     return new CrossDeviceCopyProp();
   }
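The added InferType mirrors the existing shape inference: the cross-device copy has exactly one input, and the output dtype is simply the input dtype (-1 meaning the type is still unknown). Without it, type inference could not propagate a non-default dtype through the copy node that MXNet inserts at a device-group boundary, which is what the commit title refers to. Below is a minimal sketch of the pattern this enables, assuming an MXNet 1.x build with the symbolic API; the variable names and the (2, 2) shape are illustrative, not part of the commit:

import numpy as np
import mxnet as mx

# Placing ops in two ctx groups makes MXNet insert a cross-device copy
# at the dev1 -> dev2 boundary when the groups map to different contexts.
a = mx.sym.Variable('a', dtype=np.float16)
b = mx.sym.Variable('b', dtype=np.float16)
with mx.AttrScope(ctx_group='dev1'):
    net = a + b
with mx.AttrScope(ctx_group='dev2'):
    net = net * np.float16(2)

# Binding with fp16 arguments works once the copy propagates its dtype.
exe = net.bind(mx.cpu(0),
               args={'a': mx.nd.ones((2, 2), dtype=np.float16),
                     'b': mx.nd.ones((2, 2), dtype=np.float16)},
               group2ctx={'dev1': mx.cpu(0), 'dev2': mx.cpu(1)})
exe.forward()
print(exe.outputs[0].dtype)  # numpy.float16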
38 changes: 22 additions & 16 deletions tests/python/unittest/test_model_parallel.py
@@ -17,6 +17,7 @@
 
 import numpy as np
 import mxnet as mx
+from mxnet.test_utils import *
 
 def reldiff(a, b):
     diff = np.sum(np.abs(a - b))
@@ -26,16 +27,14 @@ def reldiff(a, b):
     reldiff = diff / norm
     return reldiff

-def test_chain():
-    ctx1 = mx.cpu(0)
-    ctx2 = mx.cpu(1)
+def test_chain(ctx1=mx.cpu(0), ctx2=mx.cpu(1), dtype=np.float32):
     n = 2
-    data1 = mx.sym.Variable('data1')
-    data2 = mx.sym.Variable('data2')
-    data3 = mx.sym.Variable('data3')
+    data1 = mx.sym.Variable('data1', dtype=dtype)
+    data2 = mx.sym.Variable('data2', dtype=dtype)
+    data3 = mx.sym.Variable('data3', dtype=dtype)
     with mx.AttrScope(ctx_group='dev1'):
         net = data1 + data2
-        net = net * 3
+        net = net * dtype(3)
 
     with mx.AttrScope(ctx_group='dev2'):
         net = net + data3
@@ -45,19 +44,19 @@ def test_chain():
     shape = (4, 5)
     with mx.Context(ctx1):
         for i in range(n):
-            arr.append(mx.nd.empty(shape))
-            arr_grad.append(mx.nd.empty(shape))
+            arr.append(mx.nd.empty(shape, dtype=dtype))
+            arr_grad.append(mx.nd.empty(shape, dtype=dtype))
     with mx.Context(ctx2):
-        arr.append(mx.nd.empty(shape))
-        arr_grad.append(mx.nd.empty(shape))
+        arr.append(mx.nd.empty(shape, dtype=dtype))
+        arr_grad.append(mx.nd.empty(shape, dtype=dtype))
 
     exec1 = net.bind(ctx1,
                      args=arr,
                      args_grad=arr_grad,
                      group2ctx={'dev1': ctx1, 'dev2': ctx2})
-    arr[0][:] = 1.0
-    arr[1][:] = 2.0
-    arr[2][:] = 3.0
+    arr[0][:] = dtype(1)
+    arr[1][:] = dtype(2)
+    arr[2][:] = dtype(3)
     arr2 = [a.copyto(ctx1) for a in arr]
     arr_grad2 = [a.copyto(ctx1) for a in arr_grad]
     exec2 = net.bind(ctx1,
@@ -70,12 +69,19 @@
     exec2.forward(is_train=True)
     assert reldiff(exec1.outputs[0].asnumpy(), exec2.outputs[0].asnumpy()) < 1e-6
     out_grad = mx.nd.empty(shape, ctx1)
-    out_grad[:] = 1.0
+    out_grad[:] = dtype(1)
     exec1.backward([out_grad])
     exec2.backward([out_grad.copyto(ctx1)])
     for a, b in zip(arr_grad, arr_grad2):
         assert reldiff(a.asnumpy(), b.asnumpy()) < 1e-6
 
+def test_chain_type_device():
+    ctx_pairs = [(mx.cpu(0), mx.cpu(1))]
+    if default_context().device_type == 'gpu':
+        ctx_pairs = ctx_pairs + [(mx.gpu(0), mx.gpu(0)), (mx.cpu(0), mx.gpu(0)), (mx.gpu(0), mx.cpu(0))]
+    for ctx1, ctx2 in ctx_pairs:
+        for dtype in [np.float16, np.float32, np.float64]:
+            test_chain(ctx1, ctx2, dtype)
+
 if __name__ == '__main__':
-    test_chain()
+    test_chain_type_device()
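The module keeps its __main__ guard, so the new test matrix can be run standalone, assuming a local MXNet build is importable. A CPU-only build covers the (cpu(0), cpu(1)) pair for float16, float32, and float64; a GPU build, where default_context() reports a GPU, also covers the mixed CPU/GPU pairs:

python tests/python/unittest/test_model_parallel.py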
