Extension bug fixes (apache#19469)
* initial commit

* syntax fix

* spacing

* added test case
samskalicky committed Nov 9, 2020
1 parent 15a8864 commit d8bb575
Showing 4 changed files with 121 additions and 84 deletions.
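In practical terms, this commit lets a hybridized Gluon block be partitioned by a custom backend when its inputs are Symbol placeholders that only carry a __shape__ attribute, not just when they are concrete NDArrays, and it makes the extension plumbing tolerate arguments that were never provided. A minimal usage sketch, assuming the 'myProp' partitioner from the bundled subgraph-extension test library has already been loaded (for example via mx.library.load); the graph, names, and shapes are illustrative:

import mxnet as mx
from mxnet.gluon import nn

a = mx.sym.var('a')
b = mx.sym.var('b')
sym = (a + b) * 2                      # toy graph to partition

block = nn.SymbolBlock(sym, [a, b])
block.initialize()

# partition using Symbol placeholders that only carry shapes (enabled by this fix)
block.optimize_for(mx.sym.var('a', shape=(3, 2)),
                   mx.sym.var('b', shape=(3, 2)),
                   backend='myProp')

# partition using concrete NDArray data, then export the optimized model
block2 = nn.SymbolBlock(sym, [a, b])
block2.initialize()
block2.optimize_for(mx.nd.ones((3, 2)), mx.nd.ones((3, 2)), backend='myProp')
block2.export('optimized')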
46 changes: 32 additions & 14 deletions python/mxnet/gluon/block.py
@@ -957,14 +957,35 @@ def _build_cache(self, *args):

arg_dict, aux_dict = dict(), dict()
if self._backend:
- ctx = args[0].context
+ # set context for inputs
+ _, _, ctx_set, _ = _gather_type_ctx_info(list(args))
+ ctx = ctx_set.pop() if len(ctx_set) > 0 else None
- # get list of params in the order of out.list_arguments
- arg_dict.update({name:args[data_names[name]] if name in data_names.keys() else params[name].data()
- for name in out.list_arguments()})
- aux_dict.update({name:args[data_names[name]] if name in data_names.keys() else params[name].data()
- for name in out.list_auxiliary_states()})
- # Partition the graph.
- out = out.optimize_for(self._backend, arg_dict, aux_dict, ctx, **self._backend_opts)
+ input_shapes = dict()
+ for name in out.list_arguments():
+ if name in data_names.keys() and data_names[name] < len(args):
+ if isinstance(args[data_names[name]], NDArray):
+ arg_dict[name] = args[data_names[name]]
+ elif (isinstance(args[data_names[name]], symbol.Symbol) and
+ '__shape__' in args[data_names[name]].list_attr()):
+ shape_str = args[data_names[name]].list_attr()['__shape__']
+ input_shapes[name] = tuple(map(int, shape_str.strip('()').split(',')))
+ elif name in params:
+ arg_dict[name] = params[name].data()
+
+ for name in out.list_auxiliary_states():
+ if name in data_names.keys() and data_names[name] < len(args):
+ if isinstance(args[data_names[name]], NDArray):
+ aux_dict[name] = args[data_names[name]]
+ elif (isinstance(args[data_names[name]], symbol.Symbol) and
+ '__shape__' in args[data_names[name]].list_attr()):
+ shape_str = args[data_names[name]].list_attr()['__shape__']
+ input_shapes[name] = tuple(map(int, shape_str.strip('()').split(',')))
+ elif name in params:
+ aux_dict[name] = params[name].data()
+
+ # Partition the graph
+ out = out.optimize_for(self._backend, arg_dict, aux_dict, ctx, input_shapes, **self._backend_opts)

#update cached graph with partitioned graph
self._cached_graph = data, out
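For context, the new Symbol branch above derives input_shapes from the '__shape__' attribute string that mx.sym.var attaches when a shape is given (for example "(3, 2)"). A small self-contained sketch of that parsing step; the helper name is hypothetical, the real logic is inlined in _build_cache above:

def parse_shape_attr(shape_str):
    # "(3, 2)" -> (3, 2); mirrors shape_str.strip('()').split(',') in the diff
    return tuple(map(int, shape_str.strip('()').split(',')))

print(parse_shape_attr('(3, 2)'))   # prints (3, 2)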
@@ -1000,8 +1021,8 @@ def _build_cache(self, *args):
'Please check the backend.')

param = Parameter(name, dtype=param_data.dtype)
- param._load_init(param_data, args[0].context)
- pair = (False, param)
+ param._load_init(param_data, param_data.context)
+ pair = (False, param)

self._cached_op_args.append(pair)
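The _load_init change above means a parameter created for the partitioned graph is initialised on the context recorded in its own weight data rather than on the context of the first input. A rough sketch of that call, using Gluon's internal Parameter._load_init just as the diff does; the weight NDArray is a hypothetical stand-in for one returned by the backend:

import mxnet as mx
from mxnet.gluon import Parameter

param_data = mx.nd.ones((2, 2))                      # stand-in for a backend-created weight
param = Parameter('extra_weight', dtype=param_data.dtype)
param._load_init(param_data, param_data.context)     # was: args[0].context
print(param.data().context)                          # cpu(0)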

@@ -1103,14 +1124,11 @@ def optimize_for(self, x, *args, backend=None, backend_opts=None, clear=True, **

# do part of forward API call
has_symbol, has_ndarray, ctx_set, _ = _gather_type_ctx_info([x] + list(args))
- if has_symbol:
- raise ValueError('Inputs must be NDArrays for the optimize_for API'
- ' Please check the type of the args.\n')
if not has_symbol and not has_ndarray:
- raise ValueError('In HybridBlock, there must be one NDArray as input.'
+ raise ValueError('In HybridBlock, there must be one NDArray or one Symbol in the input.'
' Please check the type of the args.\n')
if len(ctx_set) > 1:
- raise ValueError('Find multiple contexts in the input, '
+ raise ValueError('Found multiple contexts in the input, '
'After hybridized, the HybridBlock only supports one input '
'context. You can print the ele.ctx in the '
'input arguments to inspect their contexts. '
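With the has_symbol check removed, optimize_for now accepts Symbol placeholders, and the remaining guards only reject calls with no NDArray or Symbol input, or with inputs spread across multiple contexts. A hedged sketch of the multi-context error path; it needs a GPU build and assumes the 'myProp' backend from the test library is loaded, so treat it as illustrative only:

import mxnet as mx
from mxnet.gluon import nn

a, b = mx.sym.var('a'), mx.sym.var('b')
block = nn.SymbolBlock((a + b) * 2, [a, b])
block.initialize()
try:
    block.optimize_for(mx.nd.ones((3, 2), ctx=mx.cpu()),
                       mx.nd.ones((3, 2), ctx=mx.gpu(0)),
                       backend='myProp')
except ValueError as err:
    print(err)   # "Found multiple contexts in the input, ..."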
68 changes: 36 additions & 32 deletions src/c_api/c_api.cc
@@ -1376,50 +1376,54 @@ void registerPasses(void *lib, int verbose, mxnet::ext::msgSize_t msgSize,

// convert input args
for (size_t i=0; i < in_arg_names.size(); i++) {
- arg_names.push_back(in_arg_names[i].c_str());
- const NDArray &in_arg = *(in_args_ptr[i]);
+ if (in_args_ptr[i] != nullptr) {
+ arg_names.push_back(in_arg_names[i].c_str());
+ const NDArray &in_arg = *(in_args_ptr[i]);

#if MXNET_USE_MKLDNN == 1
- // reorder data if in MKLDNN format
- if (in_arg.IsMKLDNNData()) {
- in_arg.Reorder2DefaultAsync();
- in_arg.WaitToRead();
- }
+ // reorder data if in MKLDNN format
+ if (in_arg.IsMKLDNNData()) {
+ in_arg.Reorder2DefaultAsync();
+ in_arg.WaitToRead();
+ }
#endif

- // pull out parts of NDArray to send to backend
- arg_data.push_back(in_arg.data().dptr_);
- arg_shapes.push_back(in_arg.shape().data());
- arg_dims.push_back(in_arg.shape().ndim());
- arg_types.push_back(in_arg.dtype());
- arg_verIDs.push_back(in_arg.version());
- const char* arg_ctx_str = in_arg.ctx().dev_mask() == Context::kCPU ? "cpu" : "gpu";
- arg_dev_type.push_back(arg_ctx_str);
- arg_dev_id.push_back(in_arg.ctx().real_dev_id());
+ // pull out parts of NDArray to send to backend
+ arg_data.push_back(in_arg.data().dptr_);
+ arg_shapes.push_back(in_arg.shape().data());
+ arg_dims.push_back(in_arg.shape().ndim());
+ arg_types.push_back(in_arg.dtype());
+ arg_verIDs.push_back(in_arg.version());
+ const char* arg_ctx_str = in_arg.ctx().dev_mask() == Context::kCPU ? "cpu" : "gpu";
+ arg_dev_type.push_back(arg_ctx_str);
+ arg_dev_id.push_back(in_arg.ctx().real_dev_id());
+ }
}

// convert input aux
for (size_t i=0; i < in_aux_names.size(); i++) {
- aux_names.push_back(in_aux_names[i].c_str());
- const auto &in_aux = *(in_aux_ptr[i]);
+ if (in_aux_ptr[i] != nullptr) {
+ aux_names.push_back(in_aux_names[i].c_str());
+ const auto &in_aux = *(in_aux_ptr[i]);

#if MXNET_USE_MKLDNN == 1
- // reorder data if in MKLDNN format
- if (in_aux.IsMKLDNNData()) {
- in_aux.Reorder2DefaultAsync();
- in_aux.WaitToRead();
- }
+ // reorder data if in MKLDNN format
+ if (in_aux.IsMKLDNNData()) {
+ in_aux.Reorder2DefaultAsync();
+ in_aux.WaitToRead();
+ }
#endif

- // pull out parts of NDArray to send to backend
- aux_data.push_back(in_aux.data().dptr_);
- aux_shapes.push_back(in_aux.shape().data());
- aux_dims.push_back(in_aux.shape().ndim());
- aux_types.push_back(in_aux.dtype());
- aux_verIDs.push_back(in_aux.version());
- const char* aux_ctx_str = in_aux.ctx().dev_mask() == Context::kCPU ? "cpu" : "gpu";
- aux_dev_type.push_back(aux_ctx_str);
- aux_dev_id.push_back(in_aux.ctx().real_dev_id());
+ // pull out parts of NDArray to send to backend
+ aux_data.push_back(in_aux.data().dptr_);
+ aux_shapes.push_back(in_aux.shape().data());
+ aux_dims.push_back(in_aux.shape().ndim());
+ aux_types.push_back(in_aux.dtype());
+ aux_verIDs.push_back(in_aux.version());
+ const char* aux_ctx_str = in_aux.ctx().dev_mask() == Context::kCPU ? "cpu" : "gpu";
+ aux_dev_type.push_back(aux_ctx_str);
+ aux_dev_id.push_back(in_aux.ctx().real_dev_id());
+ }
}

// convert graph to string
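The guard added above skips input arguments whose NDArray pointer is null, which now happens when the frontend partitions with Symbol placeholders instead of concrete data, so only the tensors that were actually provided are marshalled to the extension library. Roughly the same idea in Python, as a hypothetical helper for illustration only (the real logic is the C++ above):

def collect_provided_args(in_arg_names, in_args):
    arg_names, arg_shapes, arg_types = [], [], []
    for name, arr in zip(in_arg_names, in_args):
        if arr is None:                 # mirrors `if (in_args_ptr[i] != nullptr)`
            continue
        arg_names.append(name)
        arg_shapes.append(arr.shape)
        arg_types.append(arr.dtype)
    return arg_names, arg_shapes, arg_types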
74 changes: 41 additions & 33 deletions src/operator/subgraph/partitioner/custom_subgraph_property.h
@@ -208,26 +208,28 @@ class CustomSubgraphProperty: public SubgraphProperty {
arg_dev_type.clear();
arg_dev_id.clear();
for (size_t i=0; i < in_arg_names.size(); i++) {
- arg_names.push_back(in_arg_names[i].c_str());
- const NDArray &in_arg = *(in_args_ptr[i]);
+ if (in_args_ptr[i] != nullptr) {
+ arg_names.push_back(in_arg_names[i].c_str());
+ const NDArray &in_arg = *(in_args_ptr[i]);

#if MXNET_USE_MKLDNN == 1
- // reorder data if in MKLDNN format
- if (in_arg.IsMKLDNNData()) {
- in_arg.Reorder2DefaultAsync();
- in_arg.WaitToRead();
- }
+ // reorder data if in MKLDNN format
+ if (in_arg.IsMKLDNNData()) {
+ in_arg.Reorder2DefaultAsync();
+ in_arg.WaitToRead();
+ }
#endif

- // pull out parts of NDArray to send to backend
- arg_data.push_back(in_arg.data().dptr_);
- arg_shapes.push_back(in_arg.shape().data());
- arg_dims.push_back(in_arg.shape().ndim());
- arg_types.push_back(in_arg.dtype());
- arg_verIDs.push_back(in_arg.version());
- const char* arg_ctx_str = in_arg.ctx().dev_mask() == Context::kCPU ? "cpu" : "gpu";
- arg_dev_type.push_back(arg_ctx_str);
- arg_dev_id.push_back(in_arg.ctx().real_dev_id());
+ // pull out parts of NDArray to send to backend
+ arg_data.push_back(in_arg.data().dptr_);
+ arg_shapes.push_back(in_arg.shape().data());
+ arg_dims.push_back(in_arg.shape().ndim());
+ arg_types.push_back(in_arg.dtype());
+ arg_verIDs.push_back(in_arg.version());
+ const char* arg_ctx_str = in_arg.ctx().dev_mask() == Context::kCPU ? "cpu" : "gpu";
+ arg_dev_type.push_back(arg_ctx_str);
+ arg_dev_id.push_back(in_arg.ctx().real_dev_id());
+ }
}

// convert input aux
@@ -240,26 +242,28 @@ class CustomSubgraphProperty: public SubgraphProperty {
aux_dev_type.clear();
aux_dev_id.clear();
for (size_t i=0; i < in_aux_names.size(); i++) {
- aux_names.push_back(in_aux_names[i].c_str());
- const auto &in_aux = *(in_aux_ptr[i]);
+ if (in_aux_ptr[i] != nullptr) {
+ aux_names.push_back(in_aux_names[i].c_str());
+ const auto &in_aux = *(in_aux_ptr[i]);

#if MXNET_USE_MKLDNN == 1
- // reorder data if in MKLDNN format
- if (in_aux.IsMKLDNNData()) {
- in_aux.Reorder2DefaultAsync();
- in_aux.WaitToRead();
- }
+ // reorder data if in MKLDNN format
+ if (in_aux.IsMKLDNNData()) {
+ in_aux.Reorder2DefaultAsync();
+ in_aux.WaitToRead();
+ }
#endif

- // pull out parts of NDArray to send to backend
- aux_data.push_back(in_aux.data().dptr_);
- aux_shapes.push_back(in_aux.shape().data());
- aux_dims.push_back(in_aux.shape().ndim());
- aux_types.push_back(in_aux.dtype());
- aux_verIDs.push_back(in_aux.version());
- const char* aux_ctx_str = in_aux.ctx().dev_mask() == Context::kCPU ? "cpu" : "gpu";
- aux_dev_type.push_back(aux_ctx_str);
- aux_dev_id.push_back(in_aux.ctx().real_dev_id());
+ // pull out parts of NDArray to send to backend
+ aux_data.push_back(in_aux.data().dptr_);
+ aux_shapes.push_back(in_aux.shape().data());
+ aux_dims.push_back(in_aux.shape().ndim());
+ aux_types.push_back(in_aux.dtype());
+ aux_verIDs.push_back(in_aux.version());
+ const char* aux_ctx_str = in_aux.ctx().dev_mask() == Context::kCPU ? "cpu" : "gpu";
+ aux_dev_type.push_back(aux_ctx_str);
+ aux_dev_id.push_back(in_aux.ctx().real_dev_id());
+ }
}

// remove all graph attrs, some cannot be saved to json
@@ -285,13 +289,17 @@ class CustomSubgraphProperty: public SubgraphProperty {
for (unsigned oid = 0; oid < node->num_outputs(); oid++) {
const uint32_t out_entry_id = indexed_graph.entry_id(nid, oid);
mxnet::TShape& shape = shapes[out_entry_id];
- ss << shape;
+ if (shape.ndim() == -1)
+ ss << "[None]";
+ else
+ ss << shape;
if (oid < node->num_outputs()-1) ss << ",";
}
ss << "]";
node->attrs.dict[MX_STR_SHAPE] = ss.str();
}
}

// set dtype attrs for each node in the graph
if (g.HasAttr("dtype")) {
std::vector<int> dtypes = g.GetAttr<std::vector<int> >("dtype");
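The MX_STR_SHAPE change above writes unknown output shapes (ndim == -1) as [None], so a custom graph pass can still parse the attribute. A rough Python paraphrase of that serialisation step; the helper name is hypothetical, the real code streams TShape objects in the C++ above:

def format_output_shapes(shapes):
    # each entry is a tuple of ints, or None when the shape is unknown
    parts = ['[None]' if s is None else str(list(s)).replace(' ', '') for s in shapes]
    return '[' + ','.join(parts) + ']'

print(format_output_shapes([(3, 2), None]))   # [[3,2],[None]]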
17 changes: 12 additions & 5 deletions tests/python/unittest/test_extensions.py
@@ -166,16 +166,23 @@ def test_subgraph():
# check that result matches one executed by MXNet
assert_almost_equal(out[0].asnumpy(), out4[0].asnumpy(), rtol=1e-3, atol=1e-3)

- # Gluon Hybridize partitioning with shapes/types
+ # Gluon Hybridize partitioning with sym.var
sym_block2 = nn.SymbolBlock(sym, [a,b])
sym_block2.initialize()
+ a_var = mx.sym.var('a',shape=(3,2))
+ b_var = mx.sym.var('b',shape=(3,2))
+ sym_block2.optimize_for(a_var, b_var, backend='myProp')
+
+ # Gluon Hybridize partitioning with shapes/types
+ sym_block3 = nn.SymbolBlock(sym, [a,b])
+ sym_block3.initialize()
a_data = mx.nd.ones((3,2))
b_data = mx.nd.ones((3,2))
- sym_block2.optimize_for(a_data, b_data, backend='myProp')
- sym_block2.export('optimized')
- sym_block3 = nn.SymbolBlock.imports('optimized-symbol.json',['a','b'],
+ sym_block3.optimize_for(a_data, b_data, backend='myProp')
+ sym_block3.export('optimized')
+ sym_block4 = nn.SymbolBlock.imports('optimized-symbol.json',['a','b'],
'optimized-0000.params')

- out5 = sym_block3(a_data, b_data)
+ out5 = sym_block4(a_data, b_data)
# check that result matches one executed by MXNet
assert_almost_equal(out[0].asnumpy(), out5[0].asnumpy(), rtol=1e-3, atol=1e-3)
