Skip to content

Commit

Permalink
[PASS] Make placedevice compatible with backward op (apache#88)
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen authored and piiswrong committed Dec 23, 2016
1 parent 1f1147a commit f17fea0
Showing 1 changed file with 12 additions and 3 deletions.
15 changes: 12 additions & 3 deletions src/pass/place_device.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ Graph PlaceDevice(Graph src) {
const Op* copy_op = Op::Get(src.GetAttr<std::string>("device_copy_op"));
auto& device_assign_map = src.GetAttr<DeviceAssignMap>("device_assign_map");
const IndexedGraph& idx = src.indexed_graph();
static auto& is_backward =
Op::GetAttr<TIsBackward>("TIsBackward");
DeviceVector device;
// copy on write semanatics
if (src.attrs.count("device") != 0) {
Expand All @@ -45,9 +47,16 @@ Graph PlaceDevice(Graph src) {
<< "The device assignment not found for group " << device_group;
device[nid] = dit->second;
} else {
for (const IndexedGraph::NodeEntry& e : inode.inputs) {
if (device[e.node_id] != -1) {
device[nid] = device[e.node_id]; break;
if (!inode.source->is_variable() &&
is_backward.get(inode.source->op(), false)) {
if (device[inode.control_deps[0]] != -1) {
device[nid] = device[inode.control_deps[0]];
}
} else {
for (const IndexedGraph::NodeEntry& e : inode.inputs) {
if (device[e.node_id] != -1) {
device[nid] = device[e.node_id]; break;
}
}
}
}
Expand Down

0 comments on commit f17fea0

Please sign in to comment.