fix shape/dtype/storage inference. (apache#15)
* update tests.

* fix shape/dtype/storage inference.

* fix.
zheng-da authored and reminisce committed Jun 20, 2018
1 parent 1293828 · commit 987f145
Showing 2 changed files with 108 additions and 51 deletions.
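For context on what the fixed passes must produce, here is a minimal user-level sketch using the same symbols the rewritten test builds. The second dimension of data2 is declared as 0 (unknown) and has to be resolved by shape inference; the values in the comments are expectations derived from the Convolution arguments, not outputs recorded in this commit.

```python
import mxnet as mx
import numpy as np

# Same symbols as the updated test: data2's second dimension (0) is unknown
# and must be filled in by the shape pass.
data1 = mx.sym.Variable('data1', shape=(3, 3, 10, 10), dtype=np.float32)
data2 = mx.sym.Variable('data2', shape=(1, 0, 2, 2))
conv = mx.sym.Convolution(data=data1, weight=mx.sym.sin(data2),
                          kernel=(2, 2), num_filter=1)

# Expected: data2 -> (1, 3, 2, 2), i.e. (num_filter, channels of data1,
# kernel height, kernel width), since data2 feeds the Convolution weight.
arg_shapes, out_shapes, _ = conv.infer_shape()
print(dict(zip(conv.list_arguments(), arg_shapes)))

# Expected: every argument resolves to float32, propagated from data1.
arg_types, out_types, _ = conv.infer_type()
print(dict(zip(conv.list_arguments(), arg_types)))
```

The subgraph versions of these passes (SubgraphOpShape, SubgraphOpType, SubgraphOpStorageType below) must reproduce exactly these results for the partitioned symbol, which is what the rewritten test asserts.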
90 changes: 73 additions & 17 deletions src/operator/subgraph/subgraph_op.cc
@@ -154,10 +154,27 @@ bool SubgraphOpShape(const nnvm::NodeAttrs& attrs,
CHECK_EQ(idx_g.input_nodes().size(), in_shapes->size());
CHECK_EQ(idx_g.outputs().size(), out_shapes->size());
// TODO: make sure shape inputs matches the order from in_shapes
nnvm::ShapeVector shape_inputs = *in_shapes;
imperative::CheckAndInferShape(&g, std::move(shape_inputs), true);
const nnvm::ShapeVector& shapes = g.GetAttr<nnvm::ShapeVector>("shape");
const std::vector<uint32_t>& input_nids = idx_g.input_nodes();

// Put the input and output shapes to the shape vector.
nnvm::ShapeVector shapes(idx_g.num_node_entries());
const auto &input_nids = idx_g.input_nodes();
CHECK_EQ(input_nids.size(), in_shapes->size());
for (size_t i = 0; i < in_shapes->size(); i++) {
auto eid = idx_g.entry_id(input_nids[i], 0);
shapes[eid] = in_shapes->at(i);
}
CHECK_EQ(g.outputs.size(), out_shapes->size());
for (size_t i = 0; i < out_shapes->size(); i++) {
auto eid = idx_g.entry_id(g.outputs[i]);
shapes[eid] = out_shapes->at(i);
}

// Infer shape of the graph.
g.attrs["shape"] = std::make_shared<dmlc::any>(std::move(shapes));
g = exec::InferShape(std::move(g));

// Copy the inferred shape back to the input shapes and the output shapes.
shapes = g.GetAttr<nnvm::ShapeVector>("shape");
// assign to in_shapes
for (size_t i = 0; i < in_shapes->size(); ++i) {
const auto eid = idx_g.entry_id(input_nids[i], 0);
@@ -168,7 +185,8 @@ bool SubgraphOpShape(const nnvm::NodeAttrs& attrs,
const auto eid = idx_g.entry_id(g.outputs[i]);
SHAPE_ASSIGN_CHECK(*out_shapes, i, shapes[eid]);
}
return true;
// Check if we have inferred the shapes correctly.
return g.GetAttr<size_t>("shape_num_unknown_nodes") == 0;
}

bool SubgraphOpType(const nnvm::NodeAttrs& attrs,
@@ -181,10 +199,26 @@ bool SubgraphOpType(const nnvm::NodeAttrs& attrs,
CHECK_EQ(idx_g.input_nodes().size(), in_types->size());
CHECK_EQ(idx_g.outputs().size(), out_types->size());
// TODO: make sure type inputs matches the order from in_types
nnvm::DTypeVector type_inputs = *in_types;
imperative::CheckAndInferType(&g, std::move(type_inputs), true);
const nnvm::DTypeVector& types = g.GetAttr<nnvm::DTypeVector>("dtype");
const std::vector<uint32_t>& input_nids = idx_g.input_nodes();

// Put the input and output data types to the dtype vector.
nnvm::DTypeVector types(idx_g.num_node_entries(), -1);
const auto &input_nids = idx_g.input_nodes();
CHECK_EQ(input_nids.size(), in_types->size());
for (size_t i = 0; i < in_types->size(); i++) {
auto eid = idx_g.entry_id(input_nids[i], 0);
types[eid] = in_types->at(i);
}
CHECK_EQ(g.outputs.size(), out_types->size());
for (size_t i = 0; i < out_types->size(); i++) {
auto eid = idx_g.entry_id(g.outputs[i]);
types[eid] = out_types->at(i);
}

// Infer data type of the graph.
g.attrs["dtype"] = std::make_shared<dmlc::any>(std::move(types));
g = exec::InferType(std::move(g));

types = g.GetAttr<nnvm::DTypeVector>("dtype");
// assign to in_types
for (size_t i = 0; i < in_types->size(); ++i) {
const auto eid = idx_g.entry_id(input_nids[i], 0);
@@ -195,7 +229,8 @@ bool SubgraphOpType(const nnvm::NodeAttrs& attrs,
const auto eid = idx_g.entry_id(g.outputs[i]);
TYPE_ASSIGN_CHECK(*out_types, i, types[eid]);
}
return true;
// Check if we have inferred the dtypes correctly.
return g.GetAttr<size_t>("dtype_num_unknown_nodes") == 0;
}

bool SubgraphOpStorageType(const nnvm::NodeAttrs& attrs,
@@ -209,13 +244,33 @@ bool SubgraphOpStorageType(const nnvm::NodeAttrs& attrs,
const auto& idx_g = g.indexed_graph();
CHECK_EQ(idx_g.input_nodes().size(), in_stypes->size());
CHECK_EQ(idx_g.outputs().size(), out_stypes->size());
exec::DevMaskVector dev_masks(idx_g.num_nodes(), dev_mask);
exec::DevMaskVector dev_masks(idx_g.num_node_entries(), dev_mask);
// TODO: make sure type inputs matches the order from in_types
StorageTypeVector stype_inputs = *in_stypes;
imperative::CheckAndInferStorageType(&g, std::move(dev_masks),
std::move(stype_inputs), true);
const StorageTypeVector& stypes = g.GetAttr<StorageTypeVector>("storage_type");
const std::vector<uint32_t>& input_nids = idx_g.input_nodes();

// Put the input and output storages to the storage vector.
nnvm::StorageVector stypes(idx_g.num_node_entries(), exec::kBadStorageID);
const auto &input_nids = idx_g.input_nodes();
CHECK_EQ(input_nids.size(), in_stypes->size());
for (size_t i = 0; i < in_stypes->size(); i++) {
auto eid = idx_g.entry_id(input_nids[i], 0);
stypes[eid] = in_stypes->at(i);
}
CHECK_EQ(g.outputs.size(), out_stypes->size());
for (size_t i = 0; i < out_stypes->size(); i++) {
auto eid = idx_g.entry_id(g.outputs[i]);
stypes[eid] = out_stypes->at(i);
}

// Infer storage type of the graph.
bool dev_match = g.attrs.count("dev_mask") &&
g.GetAttr<exec::DevMaskVector>("dev_mask") == dev_masks;
if (!dev_match) {
g.attrs["dev_mask"] = std::make_shared<dmlc::any>(std::move(dev_masks));
}
g.attrs["storage_type"] = std::make_shared<dmlc::any>(std::move(stypes));
g = exec::InferStorageType(std::move(g));

stypes = g.GetAttr<StorageTypeVector>("storage_type");
// assign to in_types
for (size_t i = 0; i < in_stypes->size(); ++i) {
const auto eid = idx_g.entry_id(input_nids[i], 0);
@@ -228,7 +283,8 @@ bool SubgraphOpStorageType(const nnvm::NodeAttrs& attrs,
const auto eid = idx_g.entry_id(g.outputs[i]);
STORAGE_TYPE_ASSIGN_CHECK(*out_stypes, i, stypes[eid]);
}
return true;
// Check if we have inferred the storages correctly.
return g.GetAttr<size_t>("storage_type_num_unknown_nodes") == 0;
}

void SubgraphOpForward(const OpStatePtr& state_ptr,
69 changes: 35 additions & 34 deletions tests/python/unittest/test_subgraph_op.py
@@ -5,45 +5,46 @@
import numpy as np


def test_subgraph_op_whole_graph():
data1 = mx.sym.Variable('data1', shape=(3, 3, 10, 10), dtype=np.float32)
data2 = mx.sym.Variable('data2', shape=(1, 0, 2, 2))
data3 = mx.sym.sin(data2)

regular_sym = mx.sym.Convolution(data=data1, weight=data3, kernel=(2, 2), num_filter=1)

out = SymbolHandle()

op_names = []
op_names = [mx.sym.sin.__name__, mx.sym.Convolution.__name__]

check_call(_LIB.MXPartitionGraph(regular_sym.handle, mx_uint(len(op_names)),
c_str_array(op_names), ctypes.byref(out)))

subgraph_sym = Symbol(out)
assert regular_sym.list_inputs() == subgraph_sym.list_inputs()
input_names = subgraph_sym.list_inputs()
def test_subgraph():
def get_graph():
data1 = mx.sym.Variable('data1', shape=(3, 3, 10, 10), dtype=np.float32)
data2 = mx.sym.Variable('data2', shape=(1, 0, 2, 2))
data3 = mx.sym.sin(data2)
conv = mx.sym.Convolution(data=data1, weight=data3, kernel=(2, 2), num_filter=1)
rets = []
rets.append((conv, []))
rets.append((conv, [mx.sym.sin.__name__]))
rets.append((conv, [mx.sym.Convolution.__name__]))
rets.append((conv, [mx.sym.sin.__name__, mx.sym.Convolution.__name__]))
return rets

for regular_sym, op_names in get_graph():
input_names = regular_sym.list_inputs()
shapes = regular_sym.infer_shape()
types = regular_sym.infer_type()
out = SymbolHandle()

#assert regular_sym.list_outputs() == subgraph_sym.list_outputs()
print(subgraph_sym.list_outputs())
assert regular_sym.infer_shape() == subgraph_sym.infer_shape()
assert regular_sym.infer_type() == subgraph_sym.infer_type()
check_call(_LIB.MXPartitionGraph(regular_sym.handle, mx_uint(len(op_names)),
c_str_array(op_names), ctypes.byref(out)))
subgraph_sym = Symbol(out)
assert input_names == subgraph_sym.list_inputs()

regular_exec = regular_sym.simple_bind(ctx=mx.cpu(), grad_req='null')
subgraph_exec = subgraph_sym.simple_bind(ctx=mx.cpu(), grad_req='null')
print(subgraph_sym.list_outputs())
assert shapes == subgraph_sym.infer_shape()
assert types == subgraph_sym.infer_type()

regular_exec.arg_dict[data1.name][:] = mx.nd.random.normal(shape=regular_exec.arg_dict[data1.name].shape)
regular_exec.arg_dict[data2.name][:] = mx.nd.random.normal(shape=regular_exec.arg_dict[data2.name].shape)
regular_exec.arg_dict[input_names[-1]][:] = mx.nd.random.normal(shape=regular_exec.arg_dict[input_names[-1]].shape)
regular_exec = regular_sym.simple_bind(ctx=mx.cpu(), grad_req='null')
subgraph_exec = subgraph_sym.simple_bind(ctx=mx.cpu(), grad_req='null')

subgraph_exec.arg_dict[data1.name][:] = regular_exec.arg_dict[data1.name]
subgraph_exec.arg_dict[data2.name][:] = regular_exec.arg_dict[data2.name]
subgraph_exec.arg_dict[input_names[-1]][:] = regular_exec.arg_dict[input_names[-1]]
for name in input_names:
regular_exec.arg_dict[name][:] = mx.nd.random.normal(
shape=regular_exec.arg_dict[name].shape)
subgraph_exec.arg_dict[name][:] = regular_exec.arg_dict[name]

subgraph_exec.forward()
regular_exec.forward()
mx.nd.waitall()
assert (subgraph_exec.outputs[0] - regular_exec.outputs[0]).abs().sum().asscalar() == 0.0
subgraph_exec.forward()
regular_exec.forward()
mx.nd.waitall()
assert (subgraph_exec.outputs[0] - regular_exec.outputs[0]).abs().sum().asscalar() == 0.0


def test_input_name_order():
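For quick experimentation outside the test file, the partition-and-compare flow above can be condensed into a small helper. This is a sketch only: the helper name `partition` is hypothetical, and it assumes the experimental MXPartitionGraph C API introduced on this branch plus the same mxnet.base/mxnet.symbol imports the test file uses.

```python
import ctypes
from mxnet.base import _LIB, check_call, c_str_array, mx_uint, SymbolHandle
from mxnet.symbol import Symbol

def partition(sym, op_names):
    """Group the listed operators of `sym` into subgraph nodes via the
    experimental MXPartitionGraph call and return the partitioned symbol."""
    out = SymbolHandle()
    check_call(_LIB.MXPartitionGraph(sym.handle, mx_uint(len(op_names)),
                                     c_str_array(op_names), ctypes.byref(out)))
    return Symbol(out)

# Usage, with `conv` built as in the test: the partitioned symbol should keep
# the same inputs and infer the same shapes and dtypes as the original.
# sub = partition(conv, ['sin', 'Convolution'])
# assert conv.list_inputs() == sub.list_inputs()
# assert conv.infer_shape() == sub.infer_shape()
# assert conv.infer_type() == sub.infer_type()
```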