Skip to content

Commit

Permalink
Place device now compatible and tested (#33)
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen authored Aug 31, 2016
1 parent a175bd7 commit 163450f
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 18 deletions.
3 changes: 2 additions & 1 deletion include/dmlc/any.h
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,8 @@ inline const std::type_info& any::type() const {
template<typename T>
inline void any::check_type() const {
CHECK(type_ != nullptr)
<< "The any container is empty";
<< "The any container is empty"
<< " requested=" << typeid(T).name();
CHECK(type_->ptype_info == &typeid(T))
<< "The stored type mismatch"
<< " stored=" << type_->ptype_info->name()
Expand Down
34 changes: 25 additions & 9 deletions include/dmlc/parameter.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,14 @@ class FieldEntry;
// forward declare ParamManagerSingleton
template<typename PType>
struct ParamManagerSingleton;

/*! \brief option in parameter initialization */
enum ParamInitOption {
/*! \brief allow unknown parameters */
kAllowUnknown,
/*! \brief need to match exact parameters */
kAllMatch
};
} // namespace parameter
/*!
* \brief Information about a parameter field in string representations.
Expand Down Expand Up @@ -108,13 +116,17 @@ struct Parameter {
* and throw error if something wrong happens.
*
* \param kwargs map of keyword arguments, or vector of pairs
* \parma option The option on initialization.
* \tparam Container container type
* \throw ParamError when something go wrong.
*/
template<typename Container>
inline void Init(const Container &kwargs) {
inline void Init(const Container &kwargs,
parameter::ParamInitOption option = parameter::kAllowUnknown) {
PType::__MANAGER__()->RunInit(static_cast<PType*>(this),
kwargs.begin(), kwargs.end(), NULL);
kwargs.begin(), kwargs.end(),
NULL,
option == parameter::kAllowUnknown);
}
/*!
* \brief initialize the parameter by keyword arguments.
Expand All @@ -130,7 +142,8 @@ struct Parameter {
InitAllowUnknown(const Container &kwargs) {
std::vector<std::pair<std::string, std::string> > unknown;
PType::__MANAGER__()->RunInit(static_cast<PType*>(this),
kwargs.begin(), kwargs.end(), &unknown);
kwargs.begin(), kwargs.end(),
&unknown, true);
return unknown;
}
/*!
Expand Down Expand Up @@ -355,7 +368,8 @@ class ParamManager {
inline void RunInit(void *head,
RandomAccessIterator begin,
RandomAccessIterator end,
std::vector<std::pair<std::string, std::string> > *unknown_args) const {
std::vector<std::pair<std::string, std::string> > *unknown_args,
bool allow_unknown) const {
std::set<FieldAccessEntry*> selected_args;
for (RandomAccessIterator it = begin; it != end; ++it) {
FieldAccessEntry *e = Find(it->first);
Expand All @@ -367,11 +381,13 @@ class ParamManager {
if (unknown_args != NULL) {
unknown_args->push_back(*it);
} else {
std::ostringstream os;
os << "Cannot find argument \'" << it->first << "\', Possible Arguments:\n";
os << "----------------\n";
PrintDocString(os);
throw dmlc::ParamError(os.str());
if (!allow_unknown) {
std::ostringstream os;
os << "Cannot find argument \'" << it->first << "\', Possible Arguments:\n";
os << "----------------\n";
PrintDocString(os);
throw dmlc::ParamError(os.str());
}
}
}
}
Expand Down
36 changes: 28 additions & 8 deletions src/pass/place_device.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ Graph PlaceDevice(Graph src) {
const Op* copy_op = Op::Get(src.GetAttr<std::string>("device_copy_op"));
auto& device_assign_map = src.GetAttr<DeviceAssignMap>("device_assign_map");
const IndexedGraph& idx = src.indexed_graph();

DeviceVector device;
// copy on write semanatics
if (src.attrs.count("device") != 0) {
Expand Down Expand Up @@ -79,17 +78,27 @@ Graph PlaceDevice(Graph src) {
src.attrs["device"] = std::make_shared<any>(std::move(device));
return src;
}

std::map<std::tuple<uint32_t, uint32_t, int>, NodePtr> copy_map;
std::vector<NodePtr> new_node_map(idx.num_nodes(), nullptr);
std::unordered_map<const Node*, int> new_device_map;
static auto& fmutate_inputs = Op::GetAttr<FMutateInputs>("FMutateInputs");

// insert copy node
for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
int dev_id = device[nid];
const auto& inode = idx[nid];
// check if mutation is needed
bool need_mutate = false;
if (!inode.source->is_variable() && fmutate_inputs.count(inode.source->op())) {
for (uint32_t index : fmutate_inputs[inode.source->op()](inode.source->attrs)) {
auto e = inode.inputs[index];
if (new_node_map[e.node_id] != nullptr || dev_id != device[e.node_id]) {
LOG(FATAL) << " mutable state cannot go across device"
<< " op=" << inode.source->op()->name
<< " input_state_index=" << index;
}
}
}
for (const IndexedGraph::NodeEntry& e : inode.inputs) {
if (new_node_map[e.node_id] != nullptr || dev_id != device[e.node_id]) {
need_mutate = true; break;
Expand All @@ -102,6 +111,9 @@ Graph PlaceDevice(Graph src) {
}
}
}
if (inode.source->is_variable()) {
CHECK(!need_mutate) << "consistency check";
}
if (need_mutate) {
NodePtr new_node = Node::Create();
new_node->attrs = inode.source->attrs;
Expand All @@ -120,7 +132,15 @@ Graph PlaceDevice(Graph src) {
os << inode.source->inputs[i].node->attrs.name << "_" << e.index <<"_copy";
copy_node->attrs.op = copy_op;
copy_node->attrs.name = os.str();
copy_node->inputs.push_back(inode.source->inputs[i]);
if (new_node_map[e.node_id] != nullptr) {
copy_node->inputs.emplace_back(
NodeEntry{new_node_map[e.node_id], e.index, 0});
} else {
copy_node->inputs.push_back(inode.source->inputs[i]);
}
if (copy_node->attrs.op->attr_parser != nullptr) {
copy_node->attrs.op->attr_parser(&(copy_node->attrs));
}
copy_map[copy_key] = copy_node;
new_device_map[copy_node.get()] = dev_id;
new_node->inputs.emplace_back(
Expand All @@ -130,7 +150,7 @@ Graph PlaceDevice(Graph src) {
if (new_node_map[e.node_id] != nullptr) {
new_node->inputs.emplace_back(
NodeEntry{new_node_map[e.node_id], e.index, 0});
} else {
} else {
new_node->inputs.push_back(inode.source->inputs[i]);
}
}
Expand All @@ -150,7 +170,6 @@ Graph PlaceDevice(Graph src) {
new_device_map[inode.source] = dev_id;
}
}

// make the new graph
Graph ret;
for (const NodeEntry& e : src.outputs) {
Expand All @@ -163,10 +182,11 @@ Graph PlaceDevice(Graph src) {
}
DeviceVector new_device_vec(ret.indexed_graph().num_nodes());
for (uint32_t nid = 0; nid < ret.indexed_graph().num_nodes(); ++nid) {
if (new_device_map.count(ret.indexed_graph()[nid].source) == 0) {
LOG(INFO) << "canot find " << ret.indexed_graph()[nid].source->attrs.name;
auto source = ret.indexed_graph()[nid].source;
if (new_device_map.count(source) == 0) {
LOG(FATAL) << "canot find " << source;
}
new_device_vec[nid] = new_device_map.at(ret.indexed_graph()[nid].source);
new_device_vec[nid] = new_device_map.at(source);
}
ret.attrs["device"] = std::make_shared<any>(std::move(new_device_vec));
return ret;
Expand Down

0 comments on commit 163450f

Please sign in to comment.