Skip to content

Commit

Permalink
added binary operators to fill_onnx_node, refactored binary_with cons…
Browse files Browse the repository at this point in the history
…tant for consistency with location of operator in vector for regular binary operators
  • Loading branch information
graham63 committed Apr 26, 2022
1 parent adc287f commit 57f114f
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 67 deletions.
72 changes: 36 additions & 36 deletions include/lbann/operators/math/binary_with_constant.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -180,9 +180,9 @@ std::vector<onnx::NodeProto> get_onnx_nodes_impl(
AddConstantOperator<T, D> const op)
{
std::vector<onnx::NodeProto> nodes(2UL);
nodes.front() = get_constant_node(El::To<float>(op.get_constant()));
nodes.front().set_op_type("PostConstant");
nodes.back().set_op_type("Add");
nodes.front().set_op_type("Add");
nodes.back() = get_constant_node(El::To<float>(op.get_constant()));
nodes.back().set_op_type("PostConstant");
return nodes;
}

Expand All @@ -191,9 +191,9 @@ std::vector<onnx::NodeProto> get_onnx_nodes_impl(
ScaleOperator<T, D> const op)
{
std::vector<onnx::NodeProto> nodes(2UL);
nodes.front() = get_constant_node(El::To<float>(op.get_constant()));
nodes.front().set_op_type("PostConstant");
nodes.back().set_op_type("Mul");
nodes.front().set_op_type("Mul");
nodes.back() = get_constant_node(El::To<float>(op.get_constant()));
nodes.back().set_op_type("PostConstant");
return nodes;
}

Expand All @@ -202,9 +202,9 @@ std::vector<onnx::NodeProto> get_onnx_nodes_impl(
SubtractConstantOperator<T, D> const op)
{
std::vector<onnx::NodeProto> nodes(2UL);
nodes.front() = get_constant_node(El::To<float>(op.get_constant()));
nodes.front().set_op_type("PostConstant");
nodes.back().set_op_type("Sub");
nodes.front().set_op_type("Sub");
nodes.back() = get_constant_node(El::To<float>(op.get_constant()));
nodes.back().set_op_type("PostConstant");
return nodes;
}

Expand All @@ -213,9 +213,9 @@ std::vector<onnx::NodeProto> get_onnx_nodes_impl(
ConstantSubtractOperator<T, D> const op)
{
std::vector<onnx::NodeProto> nodes(2UL);
nodes.front() = get_constant_node(El::To<float>(op.get_constant()));
nodes.front().set_op_type("PreConstant");
nodes.back().set_op_type("Sub");
nodes.front().set_op_type("Sub");
nodes.back() = get_constant_node(El::To<float>(op.get_constant()));
nodes.back().set_op_type("PreConstant");
return nodes;
}

Expand All @@ -224,9 +224,9 @@ std::vector<onnx::NodeProto> get_onnx_nodes_impl(
MaxConstantOperator<T, D> const op)
{
std::vector<onnx::NodeProto> nodes(2UL);
nodes.front() = get_constant_node(El::To<float>(op.get_constant()));
nodes.front().set_op_type("PreConstant");
nodes.back().set_op_type("Max");
nodes.front().set_op_type("Max");
nodes.back() = get_constant_node(El::To<float>(op.get_constant()));
nodes.back().set_op_type("PreConstant");
return nodes;
}

Expand All @@ -235,9 +235,9 @@ std::vector<onnx::NodeProto> get_onnx_nodes_impl(
MinConstantOperator<T, D> const op)
{
std::vector<onnx::NodeProto> nodes(2UL);
nodes.front() = get_constant_node(El::To<float>(op.get_constant()));
nodes.front().set_op_type("PreConstant");
nodes.back().set_op_type("Min");
nodes.front().set_op_type("Min");
nodes.back() = get_constant_node(El::To<float>(op.get_constant()));
nodes.back().set_op_type("PreConstant");
return nodes;
}

Expand All @@ -246,9 +246,9 @@ std::vector<onnx::NodeProto> get_onnx_nodes_impl(
EqualConstantOperator<T, D> const op)
{
std::vector<onnx::NodeProto> nodes(2UL);
nodes.front() = get_constant_node(El::To<float>(op.get_constant()));
nodes.front().set_op_type("PreConstant");
nodes.back().set_op_type("Equal");
nodes.front().set_op_type("Equal");
nodes.back() = get_constant_node(El::To<float>(op.get_constant()));
nodes.back().set_op_type("PreConstant");
return nodes;
}

Expand All @@ -257,10 +257,10 @@ std::vector<onnx::NodeProto> get_onnx_nodes_impl(
NotEqualConstantOperator<T, D> const op)
{
std::vector<onnx::NodeProto> nodes(3UL);
nodes.front() = get_constant_node(El::To<float>(op.get_constant()));
nodes.front().set_op_type("PreConstant");
nodes.front().set_op_type("Equal");
nodes.back() = get_constant_node(El::To<float>(op.get_constant()));
nodes.back().set_op_type("PreConstant");
nodes.at(1).set_op_type("Not");
nodes.back().set_op_type("Equal");
return nodes;
}

Expand All @@ -269,9 +269,9 @@ std::vector<onnx::NodeProto> get_onnx_nodes_impl(
LessConstantOperator<T, D> const op)
{
std::vector<onnx::NodeProto> nodes(2UL);
nodes.front() = get_constant_node(El::To<float>(op.get_constant()));
nodes.front().set_op_type("PostConstant");
nodes.back().set_op_type("Less");
nodes.front().set_op_type("Less");
nodes.back() = get_constant_node(El::To<float>(op.get_constant()));
nodes.back().set_op_type("PostConstant");
return nodes;
}

Expand All @@ -280,9 +280,9 @@ std::vector<onnx::NodeProto> get_onnx_nodes_impl(
LessEqualConstantOperator<T, D> const op)
{
std::vector<onnx::NodeProto> nodes(2UL);
nodes.front() = get_constant_node(El::To<float>(op.get_constant()));
nodes.front().set_op_type("PostConstant");
nodes.back().set_op_type("LessOrEqual");
nodes.front().set_op_type("LessOrEqual");
nodes.back() = get_constant_node(El::To<float>(op.get_constant()));
nodes.back().set_op_type("PostConstant");
return nodes;
}

Expand All @@ -291,9 +291,9 @@ std::vector<onnx::NodeProto> get_onnx_nodes_impl(
GreaterConstantOperator<T, D> const op)
{
std::vector<onnx::NodeProto> nodes(2UL);
nodes.front() = get_constant_node(El::To<float>(op.get_constant()));
nodes.front().set_op_type("PreConstant");
nodes.back().set_op_type("Greater");
nodes.front().set_op_type("Greater");
nodes.back() = get_constant_node(El::To<float>(op.get_constant()));
nodes.back().set_op_type("PreConstant");
return nodes;
}

Expand All @@ -302,9 +302,9 @@ std::vector<onnx::NodeProto> get_onnx_nodes_impl(
GreaterEqualConstantOperator<T, D> const op)
{
std::vector<onnx::NodeProto> nodes(2UL);
nodes.front() = get_constant_node(El::To<float>(op.get_constant()));
nodes.front().set_op_type("PreConstant");
nodes.back().set_op_type("GreaterOrEqual");
nodes.front().set_op_type("GreaterOrEqual");
nodes.back() = get_constant_node(El::To<float>(op.get_constant()));
nodes.back().set_op_type("PreConstant");
return nodes;
}

Expand Down
70 changes: 39 additions & 31 deletions src/layers/operator_layer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,57 +48,65 @@ template <typename T, typename O, data_layout L, El::Device D>
void OperatorLayer<T, O, L, D>::fill_onnx_node(
onnx::GraphProto& graph) const
{
std::vector<onnx::NodeProto> nodes(2UL);
nodes.front().add_attribute()->set_type(onnx::AttributeProto::FLOAT);
nodes.front().add_attribute()->set_f(El::To<float>(5));
nodes.front().set_op_type("PostConstant");
nodes.back().set_op_type("Add");
const auto& parents = this->get_parent_layers();
auto nodes = m_ops.front()->get_onnx_nodes();

//OperatorPtr op;
//auto nodes = op->get_onnx_nodes();
const auto* parent = this->get_parent_layers()[0];
auto* op_node = graph.add_node();
*op_node = nodes.front();

auto* const_node = graph.add_node();
*const_node = nodes.front();
op_node->set_name(this->get_name());
op_node->set_domain("");
op_node->set_doc_string(this->get_name());

auto* node = graph.add_node();
*node = nodes.back();
node->set_name(this->get_name());
node->set_domain("");
node->set_doc_string(this->get_name());
if(const_node->op_type() == "PostConstant")
//binary operators
if(nodes.size() == 1)
{
node->add_input(parent->get_name() + "_0");
node->add_input(const_node->output(0));
const_node->set_op_type("Constant");
for(auto* parent : parents)
{
size_t idx = parent->find_child_layer_index(*this);
op_node->add_input(parent->get_name() + "_" + std::to_string(idx));
}
}
else if(const_node->op_type() == "PreConstant")
// Binary w/ constant operators
else if(nodes.size() == 2 || nodes.size() == 3)
{
node->add_input(const_node->output(0));
node->add_input(parent->get_name() + "_0");
auto* const_node = graph.add_node();
*const_node = nodes.back();
if(const_node->op_type() == "PostConstant")
{
op_node->add_input(parents[0]->get_name() + "_0");
op_node->add_input(const_node->output(0));
}
else if(const_node->op_type() == "PreConstant")
{
op_node->add_input(const_node->output(0));
op_node->add_input(parents[0]->get_name() + "_0");
}
else
LBANN_ERROR("Unknown onnx op type for constant.");

const_node->set_op_type("Constant");
}
else
LBANN_ERROR("Unknown onnx op type for constant.");
LBANN_ERROR("Expected 1-3 ONNX nodes for binary operation, received ", nodes.size());

// Not equal operator
if(nodes.size() == 3)
{
node->add_output("EqualOperator");
op_node->add_output("EqualOperator");
auto* not_node = graph.add_node();
not_node->add_input(node->output(0));
not_node->add_output(this->get_child_layers()[0]->get_name() + "_0");
not_node->add_input(op_node->output(0));
not_node->set_name("Not operator");
not_node->set_op_type("Not");
not_node->set_domain("");
not_node->set_doc_string("Not node for not equal operation.");
op_node = not_node;
}
else if(nodes.size() == 2)
{
node->add_output(this->get_child_layers()[0]->get_name() + "_0");

for (auto const* child : this->get_child_layers()) {
auto idx = this->find_child_layer_index(*child);
op_node->add_output(this->get_name() + "_" + std::to_string(idx));
}
else
LBANN_ERROR("Expected two or three nodes for binary constant operation, received ", nodes.size());
}
#endif // LBANN_HAS_ONNX

Expand Down

0 comments on commit 57f114f

Please sign in to comment.