This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Extension bug fixes #19469

Merged
merged 4 commits on Nov 9, 2020
44 changes: 31 additions & 13 deletions python/mxnet/gluon/block.py
@@ -1029,14 +1029,35 @@ def _build_cache(self, *args):

        arg_dict, aux_dict = dict(), dict()
        if self._backend:
            ctx = args[0].context
            # set context for inputs
            _, _, ctx_set, _ = _gather_type_ctx_info(list(args))
            ctx = ctx_set.pop() if len(ctx_set) > 0 else None
            # get list of params in the order of out.list_arguments
            arg_dict.update({name:args[data_names[name]] if name in data_names.keys() else params[name].data()
                            for name in out.list_arguments()})
            aux_dict.update({name:args[data_names[name]] if name in data_names.keys() else params[name].data()
                            for name in out.list_auxiliary_states()})
            # Partition the graph.
            out = out.optimize_for(self._backend, arg_dict, aux_dict, ctx, **self._backend_opts)
            input_shapes = dict()
            for name in out.list_arguments():
                if name in data_names.keys() and data_names[name] < len(args):
                    if isinstance(args[data_names[name]], NDArray):
                        arg_dict[name] = args[data_names[name]]
                    elif (isinstance(args[data_names[name]], symbol.Symbol) and
                          '__shape__' in args[data_names[name]].list_attr()):
                        shape_str = args[data_names[name]].list_attr()['__shape__']
                        input_shapes[name] = tuple(map(int, shape_str.strip('()').split(',')))
                elif name in params:
                    arg_dict[name] = params[name].data()

            for name in out.list_auxiliary_states():
                if name in data_names.keys() and data_names[name] < len(args):
                    if isinstance(args[data_names[name]], NDArray):
                        aux_dict[name] = args[data_names[name]]
                    elif (isinstance(args[data_names[name]], symbol.Symbol) and
                          '__shape__' in args[data_names[name]].list_attr()):
                        shape_str = args[data_names[name]].list_attr()['__shape__']
                        input_shapes[name] = tuple(map(int, shape_str.strip('()').split(',')))
                elif name in params:
                    aux_dict[name] = params[name].data()

            # Partition the graph
            out = out.optimize_for(self._backend, arg_dict, aux_dict, ctx, input_shapes, **self._backend_opts)

        # convert to numpy symbol if needed
        if _mx_npx.is_np_array():
@@ -1079,7 +1100,7 @@ def _build_cache(self, *args):
                param = Parameter(name, dtype=param_data.dtype)
                param._var_name = name
                serialization_name = name # HybridBlock.export
                param._load_init(param_data, args[0].context)
                param._load_init(param_data, param_data.context)
                triple = (False, serialization_name, param)

            self._cached_op_args.append(triple)
@@ -1181,14 +1202,11 @@ def optimize_for(self, x, *args, backend=None, backend_opts=None, clear=True, **

        # do part of forward API call
        has_symbol, has_ndarray, ctx_set, _ = _gather_type_ctx_info([x] + list(args))
        if has_symbol:
            raise ValueError('Inputs must be NDArrays for the optimize_for API'
                             ' Please check the type of the args.\n')
        if not has_symbol and not has_ndarray:
            raise ValueError('In HybridBlock, there must be one NDArray as input.'
            raise ValueError('In HybridBlock, there must be one NDArray or one Symbol in the input.'
                             ' Please check the type of the args.\n')
        if len(ctx_set) > 1:
            raise ValueError('Find multiple contexts in the input, '
            raise ValueError('Found multiple contexts in the input, '
                             'After hybridized, the HybridBlock only supports one input '
                             'context. You can print the ele.ctx in the '
                             'input arguments to inspect their contexts. '
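The reworked `_build_cache` above accepts a mix of NDArray and Symbol inputs: an NDArray still lands in `arg_dict`/`aux_dict`, while a placeholder Symbol that carries a `__shape__` attribute contributes only its shape, collected into `input_shapes` and forwarded to `optimize_for`. Below is a minimal sketch of that shape-string parsing, assuming the `'(3, 2)'`-style attribute format implied by the code above; the helper name `parse_shape_attr` is illustrative, not part of the PR.

```python
import mxnet as mx

def parse_shape_attr(shape_str):
    """Turn a '__shape__' attribute string such as '(3, 2)' into the tuple (3, 2),
    mirroring the strip('()')/split(',') parsing in _build_cache above."""
    return tuple(map(int, shape_str.strip('()').split(',')))

# A shape-only placeholder input, as used by the updated test further down:
a_var = mx.sym.var('a', shape=(3, 2))
attrs = a_var.list_attr()                      # expected to contain '__shape__'
if '__shape__' in attrs:
    input_shapes = {'a': parse_shape_attr(attrs['__shape__'])}
    print(input_shapes)                        # {'a': (3, 2)}
```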
68 changes: 36 additions & 32 deletions src/c_api/c_api.cc
@@ -1377,50 +1377,54 @@ void registerPasses(void *lib, int verbose, mxnet::ext::msgSize_t msgSize,

  // convert input args
  for (size_t i=0; i < in_arg_names.size(); i++) {
    arg_names.push_back(in_arg_names[i].c_str());
    const NDArray &in_arg = *(in_args_ptr[i]);
    if (in_args_ptr[i] != nullptr) {
      arg_names.push_back(in_arg_names[i].c_str());
      const NDArray &in_arg = *(in_args_ptr[i]);

#if MXNET_USE_MKLDNN == 1
    // reorder data if in MKLDNN format
    if (in_arg.IsMKLDNNData()) {
      in_arg.Reorder2DefaultAsync();
      in_arg.WaitToRead();
    }
      // reorder data if in MKLDNN format
      if (in_arg.IsMKLDNNData()) {
        in_arg.Reorder2DefaultAsync();
        in_arg.WaitToRead();
      }
#endif

    // pull out parts of NDArray to send to backend
    arg_data.push_back(in_arg.data().dptr_);
    arg_shapes.push_back(in_arg.shape().data());
    arg_dims.push_back(in_arg.shape().ndim());
    arg_types.push_back(in_arg.dtype());
    arg_verIDs.push_back(in_arg.version());
    const char* arg_ctx_str = in_arg.ctx().dev_mask() == Context::kCPU ? "cpu" : "gpu";
    arg_dev_type.push_back(arg_ctx_str);
    arg_dev_id.push_back(in_arg.ctx().real_dev_id());
      // pull out parts of NDArray to send to backend
      arg_data.push_back(in_arg.data().dptr_);
      arg_shapes.push_back(in_arg.shape().data());
      arg_dims.push_back(in_arg.shape().ndim());
      arg_types.push_back(in_arg.dtype());
      arg_verIDs.push_back(in_arg.version());
      const char* arg_ctx_str = in_arg.ctx().dev_mask() == Context::kCPU ? "cpu" : "gpu";
      arg_dev_type.push_back(arg_ctx_str);
      arg_dev_id.push_back(in_arg.ctx().real_dev_id());
    }
  }

  // convert input aux
  for (size_t i=0; i < in_aux_names.size(); i++) {
    aux_names.push_back(in_aux_names[i].c_str());
    const auto &in_aux = *(in_aux_ptr[i]);
    if (in_aux_ptr[i] != nullptr) {
      aux_names.push_back(in_aux_names[i].c_str());
      const auto &in_aux = *(in_aux_ptr[i]);

#if MXNET_USE_MKLDNN == 1
    // reorder data if in MKLDNN format
    if (in_aux.IsMKLDNNData()) {
      in_aux.Reorder2DefaultAsync();
      in_aux.WaitToRead();
    }
      // reorder data if in MKLDNN format
      if (in_aux.IsMKLDNNData()) {
        in_aux.Reorder2DefaultAsync();
        in_aux.WaitToRead();
      }
#endif

    // pull out parts of NDArray to send to backend
    aux_data.push_back(in_aux.data().dptr_);
    aux_shapes.push_back(in_aux.shape().data());
    aux_dims.push_back(in_aux.shape().ndim());
    aux_types.push_back(in_aux.dtype());
    aux_verIDs.push_back(in_aux.version());
    const char* aux_ctx_str = in_aux.ctx().dev_mask() == Context::kCPU ? "cpu" : "gpu";
    aux_dev_type.push_back(aux_ctx_str);
    aux_dev_id.push_back(in_aux.ctx().real_dev_id());
      // pull out parts of NDArray to send to backend
      aux_data.push_back(in_aux.data().dptr_);
      aux_shapes.push_back(in_aux.shape().data());
      aux_dims.push_back(in_aux.shape().ndim());
      aux_types.push_back(in_aux.dtype());
      aux_verIDs.push_back(in_aux.version());
      const char* aux_ctx_str = in_aux.ctx().dev_mask() == Context::kCPU ? "cpu" : "gpu";
      aux_dev_type.push_back(aux_ctx_str);
      aux_dev_id.push_back(in_aux.ctx().real_dev_id());
    }
  }

  // convert graph to string
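The `registerPasses` change above wraps the per-argument conversion in a null check, so inputs that have no NDArray data (for example, graph inputs supplied only as shape-bearing Symbols) are skipped instead of dereferenced. Below is a conceptual Python sketch of the per-argument record being assembled and of the skip-on-null behaviour; the helper name and dict layout are illustrative, while the real code fills parallel C arrays (`arg_names`, `arg_data`, `arg_shapes`, `arg_dims`, ...).

```python
import mxnet as mx

def collect_arg_info(names, arrays):
    """arrays may contain None for arguments that have no NDArray data."""
    infos = []
    for name, arr in zip(names, arrays):
        if arr is None:                            # the new guard: skip missing args
            continue
        infos.append({
            'name': name,
            'shape': arr.shape,                    # cf. arg_shapes / arg_dims
            'dtype': arr.dtype,                    # cf. arg_types
            'dev_type': arr.context.device_type,   # 'cpu' or 'gpu', cf. arg_dev_type
            'dev_id': arr.context.device_id,       # cf. arg_dev_id
        })
    return infos

# Only 'w' has data here; 'x' is skipped rather than dereferenced as a null pointer.
print(collect_arg_info(['w', 'x'], [mx.nd.ones((3, 2)), None]))
```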
74 changes: 41 additions & 33 deletions src/operator/subgraph/partitioner/custom_subgraph_property.h
@@ -208,26 +208,28 @@ class CustomSubgraphProperty: public SubgraphProperty {
    arg_dev_type.clear();
    arg_dev_id.clear();
    for (size_t i=0; i < in_arg_names.size(); i++) {
      arg_names.push_back(in_arg_names[i].c_str());
      const NDArray &in_arg = *(in_args_ptr[i]);
      if (in_args_ptr[i] != nullptr) {
        arg_names.push_back(in_arg_names[i].c_str());
        const NDArray &in_arg = *(in_args_ptr[i]);

#if MXNET_USE_MKLDNN == 1
      // reorder data if in MKLDNN format
      if (in_arg.IsMKLDNNData()) {
        in_arg.Reorder2DefaultAsync();
        in_arg.WaitToRead();
      }
        // reorder data if in MKLDNN format
        if (in_arg.IsMKLDNNData()) {
          in_arg.Reorder2DefaultAsync();
          in_arg.WaitToRead();
        }
#endif

      // pull out parts of NDArray to send to backend
      arg_data.push_back(in_arg.data().dptr_);
      arg_shapes.push_back(in_arg.shape().data());
      arg_dims.push_back(in_arg.shape().ndim());
      arg_types.push_back(in_arg.dtype());
      arg_verIDs.push_back(in_arg.version());
      const char* arg_ctx_str = in_arg.ctx().dev_mask() == Context::kCPU ? "cpu" : "gpu";
      arg_dev_type.push_back(arg_ctx_str);
      arg_dev_id.push_back(in_arg.ctx().real_dev_id());
        // pull out parts of NDArray to send to backend
        arg_data.push_back(in_arg.data().dptr_);
        arg_shapes.push_back(in_arg.shape().data());
        arg_dims.push_back(in_arg.shape().ndim());
        arg_types.push_back(in_arg.dtype());
        arg_verIDs.push_back(in_arg.version());
        const char* arg_ctx_str = in_arg.ctx().dev_mask() == Context::kCPU ? "cpu" : "gpu";
        arg_dev_type.push_back(arg_ctx_str);
        arg_dev_id.push_back(in_arg.ctx().real_dev_id());
      }
    }

    // convert input aux
@@ -240,26 +242,28 @@ class CustomSubgraphProperty: public SubgraphProperty {
    aux_dev_type.clear();
    aux_dev_id.clear();
    for (size_t i=0; i < in_aux_names.size(); i++) {
      aux_names.push_back(in_aux_names[i].c_str());
      const auto &in_aux = *(in_aux_ptr[i]);
      if (in_aux_ptr[i] != nullptr) {
        aux_names.push_back(in_aux_names[i].c_str());
        const auto &in_aux = *(in_aux_ptr[i]);

#if MXNET_USE_MKLDNN == 1
      // reorder data if in MKLDNN format
      if (in_aux.IsMKLDNNData()) {
        in_aux.Reorder2DefaultAsync();
        in_aux.WaitToRead();
      }
        // reorder data if in MKLDNN format
        if (in_aux.IsMKLDNNData()) {
          in_aux.Reorder2DefaultAsync();
          in_aux.WaitToRead();
        }
#endif

      // pull out parts of NDArray to send to backend
      aux_data.push_back(in_aux.data().dptr_);
      aux_shapes.push_back(in_aux.shape().data());
      aux_dims.push_back(in_aux.shape().ndim());
      aux_types.push_back(in_aux.dtype());
      aux_verIDs.push_back(in_aux.version());
      const char* aux_ctx_str = in_aux.ctx().dev_mask() == Context::kCPU ? "cpu" : "gpu";
      aux_dev_type.push_back(aux_ctx_str);
      aux_dev_id.push_back(in_aux.ctx().real_dev_id());
        // pull out parts of NDArray to send to backend
        aux_data.push_back(in_aux.data().dptr_);
        aux_shapes.push_back(in_aux.shape().data());
        aux_dims.push_back(in_aux.shape().ndim());
        aux_types.push_back(in_aux.dtype());
        aux_verIDs.push_back(in_aux.version());
        const char* aux_ctx_str = in_aux.ctx().dev_mask() == Context::kCPU ? "cpu" : "gpu";
        aux_dev_type.push_back(aux_ctx_str);
        aux_dev_id.push_back(in_aux.ctx().real_dev_id());
      }
    }

    // remove all graph attrs, some cannot be saved to json
@@ -285,13 +289,17 @@ class CustomSubgraphProperty: public SubgraphProperty {
        for (unsigned oid = 0; oid < node->num_outputs(); oid++) {
          const uint32_t out_entry_id = indexed_graph.entry_id(nid, oid);
          mxnet::TShape& shape = shapes[out_entry_id];
          ss << shape;
          if (shape.ndim() == -1)
            ss << "[None]";
          else
            ss << shape;
          if (oid < node->num_outputs()-1) ss << ",";
        }
        ss << "]";
        node->attrs.dict[MX_STR_SHAPE] = ss.str();
      }
    }

    // set dtype attrs for each node in the graph
    if (g.HasAttr("dtype")) {
      std::vector<int> dtypes = g.GetAttr<std::vector<int> >("dtype");
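The last hunk above also changes how per-node output shapes are serialized into the `MX_STR_SHAPE` attribute: a shape with `ndim() == -1` (unknown) is now written as `[None]`. A small Python sketch of the resulting string, assuming the outer-bracket/comma layout implied by the stream operations above; the helper name is illustrative.

```python
def format_node_shapes(shapes):
    """shapes: list of tuples, with None standing in for an unknown (ndim == -1) shape."""
    parts = []
    for shp in shapes:
        if shp is None:
            parts.append('[None]')                                  # new unknown-shape case
        else:
            parts.append('[' + ','.join(str(d) for d in shp) + ']')
    return '[' + ','.join(parts) + ']'

# A node with one known and one unknown output shape:
print(format_node_shapes([(3, 2), None]))   # [[3,2],[None]]
```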
17 changes: 12 additions & 5 deletions tests/python/unittest/test_extensions.py
@@ -166,16 +166,23 @@ def test_subgraph():
    # check that result matches one executed by MXNet
    assert_almost_equal(out[0].asnumpy(), out4[0].asnumpy(), rtol=1e-3, atol=1e-3)

    # Gluon Hybridize partitioning with shapes/types
    # Gluon Hybridize partitioning with sym.var
    sym_block2 = nn.SymbolBlock(sym, [a,b])
    sym_block2.initialize()
    a_var = mx.sym.var('a',shape=(3,2))
    b_var = mx.sym.var('b',shape=(3,2))
    sym_block2.optimize_for(a_var, b_var, backend='myProp')

    # Gluon Hybridize partitioning with shapes/types
    sym_block3 = nn.SymbolBlock(sym, [a,b])
    sym_block3.initialize()
    a_data = mx.nd.ones((3,2))
    b_data = mx.nd.ones((3,2))
    sym_block2.optimize_for(a_data, b_data, backend='myProp')
    sym_block2.export('optimized')
    sym_block3 = nn.SymbolBlock.imports('optimized-symbol.json',['a','b'],
    sym_block3.optimize_for(a_data, b_data, backend='myProp')
    sym_block3.export('optimized')
    sym_block4 = nn.SymbolBlock.imports('optimized-symbol.json',['a','b'],
                                        'optimized-0000.params')

    out5 = sym_block3(a_data, b_data)
    out5 = sym_block4(a_data, b_data)
    # check that result matches one executed by MXNet
    assert_almost_equal(out[0].asnumpy(), out5[0].asnumpy(), rtol=1e-3, atol=1e-3)
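The `'myProp'` backend exercised by this test comes from the sample extension library that the test suite builds; a library like that has to be loaded before `optimize_for` can find the backend. A hedged reminder of the loading step, with an illustrative path rather than the one the test actually uses:

```python
import mxnet as mx

# Load the compiled extension library so its custom partitioner (e.g. 'myProp')
# is registered with MXNet before optimize_for is called.
mx.library.load('libsubgraph_lib.so')   # illustrative path to the built extension
```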