Skip to content

Commit

Permalink
Clang-format
Browse files Browse the repository at this point in the history
  • Loading branch information
benson31 authored and graham63 committed Jun 25, 2022
1 parent 72ad812 commit 1aa9858
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 42 deletions.
2 changes: 1 addition & 1 deletion include/lbann/layers/operator_layer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ class OperatorLayer final : public data_type_layer<InputT, OutputT>

#ifdef LBANN_HAS_ONNX
void fill_onnx_node(onnx::GraphProto& graph) const override;
#endif //LBANN_HAS_ONNX
#endif // LBANN_HAS_ONNX

void fp_compute() final;
void bp_compute() final;
Expand Down
47 changes: 23 additions & 24 deletions include/lbann/operators/math/binary_with_constant.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -176,8 +176,8 @@ inline onnx::NodeProto get_constant_node(float val)
}

template <typename T, El::Device D>
std::vector<onnx::NodeProto> get_onnx_nodes_impl(
AddConstantOperator<T, D> const op)
std::vector<onnx::NodeProto>
get_onnx_nodes_impl(AddConstantOperator<T, D> const op)
{
std::vector<onnx::NodeProto> nodes(2UL);
nodes.front().set_op_type("Add");
Expand All @@ -187,8 +187,7 @@ std::vector<onnx::NodeProto> get_onnx_nodes_impl(
}

template <typename T, El::Device D>
std::vector<onnx::NodeProto> get_onnx_nodes_impl(
ScaleOperator<T, D> const op)
std::vector<onnx::NodeProto> get_onnx_nodes_impl(ScaleOperator<T, D> const op)
{
std::vector<onnx::NodeProto> nodes(2UL);
nodes.front().set_op_type("Mul");
Expand All @@ -198,8 +197,8 @@ std::vector<onnx::NodeProto> get_onnx_nodes_impl(
}

template <typename T, El::Device D>
std::vector<onnx::NodeProto> get_onnx_nodes_impl(
SubtractConstantOperator<T, D> const op)
std::vector<onnx::NodeProto>
get_onnx_nodes_impl(SubtractConstantOperator<T, D> const op)
{
std::vector<onnx::NodeProto> nodes(2UL);
nodes.front().set_op_type("Sub");
Expand All @@ -209,8 +208,8 @@ std::vector<onnx::NodeProto> get_onnx_nodes_impl(
}

template <typename T, El::Device D>
std::vector<onnx::NodeProto> get_onnx_nodes_impl(
ConstantSubtractOperator<T, D> const op)
std::vector<onnx::NodeProto>
get_onnx_nodes_impl(ConstantSubtractOperator<T, D> const op)
{
std::vector<onnx::NodeProto> nodes(2UL);
nodes.front().set_op_type("Sub");
Expand All @@ -220,8 +219,8 @@ std::vector<onnx::NodeProto> get_onnx_nodes_impl(
}

template <typename T, El::Device D>
std::vector<onnx::NodeProto> get_onnx_nodes_impl(
MaxConstantOperator<T, D> const op)
std::vector<onnx::NodeProto>
get_onnx_nodes_impl(MaxConstantOperator<T, D> const op)
{
std::vector<onnx::NodeProto> nodes(2UL);
nodes.front().set_op_type("Max");
Expand All @@ -231,8 +230,8 @@ std::vector<onnx::NodeProto> get_onnx_nodes_impl(
}

template <typename T, El::Device D>
std::vector<onnx::NodeProto> get_onnx_nodes_impl(
MinConstantOperator<T, D> const op)
std::vector<onnx::NodeProto>
get_onnx_nodes_impl(MinConstantOperator<T, D> const op)
{
std::vector<onnx::NodeProto> nodes(2UL);
nodes.front().set_op_type("Min");
Expand All @@ -242,8 +241,8 @@ std::vector<onnx::NodeProto> get_onnx_nodes_impl(
}

template <typename T, El::Device D>
std::vector<onnx::NodeProto> get_onnx_nodes_impl(
EqualConstantOperator<T, D> const op)
std::vector<onnx::NodeProto>
get_onnx_nodes_impl(EqualConstantOperator<T, D> const op)
{
std::vector<onnx::NodeProto> nodes(2UL);
nodes.front().set_op_type("Equal");
Expand All @@ -253,8 +252,8 @@ std::vector<onnx::NodeProto> get_onnx_nodes_impl(
}

template <typename T, El::Device D>
std::vector<onnx::NodeProto> get_onnx_nodes_impl(
NotEqualConstantOperator<T, D> const op)
std::vector<onnx::NodeProto>
get_onnx_nodes_impl(NotEqualConstantOperator<T, D> const op)
{
std::vector<onnx::NodeProto> nodes(3UL);
nodes.front().set_op_type("Equal");
Expand All @@ -265,8 +264,8 @@ std::vector<onnx::NodeProto> get_onnx_nodes_impl(
}

template <typename T, El::Device D>
std::vector<onnx::NodeProto> get_onnx_nodes_impl(
LessConstantOperator<T, D> const op)
std::vector<onnx::NodeProto>
get_onnx_nodes_impl(LessConstantOperator<T, D> const op)
{
std::vector<onnx::NodeProto> nodes(2UL);
nodes.front().set_op_type("Less");
Expand All @@ -276,8 +275,8 @@ std::vector<onnx::NodeProto> get_onnx_nodes_impl(
}

template <typename T, El::Device D>
std::vector<onnx::NodeProto> get_onnx_nodes_impl(
LessEqualConstantOperator<T, D> const op)
std::vector<onnx::NodeProto>
get_onnx_nodes_impl(LessEqualConstantOperator<T, D> const op)
{
std::vector<onnx::NodeProto> nodes(2UL);
nodes.front().set_op_type("LessOrEqual");
Expand All @@ -287,8 +286,8 @@ std::vector<onnx::NodeProto> get_onnx_nodes_impl(
}

template <typename T, El::Device D>
std::vector<onnx::NodeProto> get_onnx_nodes_impl(
GreaterConstantOperator<T, D> const op)
std::vector<onnx::NodeProto>
get_onnx_nodes_impl(GreaterConstantOperator<T, D> const op)
{
std::vector<onnx::NodeProto> nodes(2UL);
nodes.front().set_op_type("Greater");
Expand All @@ -298,8 +297,8 @@ std::vector<onnx::NodeProto> get_onnx_nodes_impl(
}

template <typename T, El::Device D>
std::vector<onnx::NodeProto> get_onnx_nodes_impl(
GreaterEqualConstantOperator<T, D> const op)
std::vector<onnx::NodeProto>
get_onnx_nodes_impl(GreaterEqualConstantOperator<T, D> const op)
{
std::vector<onnx::NodeProto> nodes(2UL);
nodes.front().set_op_type("GreaterOrEqual");
Expand Down
3 changes: 2 additions & 1 deletion include/lbann/operators/operator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,8 @@ void Operator<InputT, OutputT, D>::serialize(ArchiveT& ar)

#ifdef LBANN_HAS_ONNX
template <typename InputT, typename OutputT, El::Device D>
std::vector<onnx::NodeProto> Operator<InputT, OutputT, D>::get_onnx_nodes() const
std::vector<onnx::NodeProto>
Operator<InputT, OutputT, D>::get_onnx_nodes() const
{
// The default assumption is that we don't know how to represent
// this operator in ONNX terms yet.
Expand Down
26 changes: 10 additions & 16 deletions src/layers/operator_layer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,7 @@ namespace lbann {

#ifdef LBANN_HAS_ONNX
template <typename T, typename O, data_layout L, El::Device D>
void OperatorLayer<T, O, L, D>::fill_onnx_node(
onnx::GraphProto& graph) const
void OperatorLayer<T, O, L, D>::fill_onnx_node(onnx::GraphProto& graph) const
{
const auto& parents = this->get_parent_layers();
auto nodes = m_ops.front()->get_onnx_nodes();
Expand All @@ -58,27 +57,22 @@ void OperatorLayer<T, O, L, D>::fill_onnx_node(
op_node->set_domain("");
op_node->set_doc_string(this->get_name());

//binary operators
if(nodes.size() == 1)
{
for(auto* parent : parents)
{
// binary operators
if (nodes.size() == 1) {
for (auto* parent : parents) {
size_t idx = parent->find_child_layer_index(*this);
op_node->add_input(parent->get_name() + "_" + std::to_string(idx));
}
}
// Binary w/ constant operators
else if(nodes.size() == 2 || nodes.size() == 3)
{
else if (nodes.size() == 2 || nodes.size() == 3) {
auto* const_node = graph.add_node();
*const_node = nodes.back();
if(const_node->op_type() == "PostConstant")
{
if (const_node->op_type() == "PostConstant") {
op_node->add_input(parents[0]->get_name() + "_0");
op_node->add_input(const_node->output(0));
}
else if(const_node->op_type() == "PreConstant")
{
else if (const_node->op_type() == "PreConstant") {
op_node->add_input(const_node->output(0));
op_node->add_input(parents[0]->get_name() + "_0");
}
Expand All @@ -88,11 +82,11 @@ void OperatorLayer<T, O, L, D>::fill_onnx_node(
const_node->set_op_type("Constant");
}
else
LBANN_ERROR("Expected 1-3 ONNX nodes for binary operation, received ", nodes.size());
LBANN_ERROR("Expected 1-3 ONNX nodes for binary operation, received ",
nodes.size());

// Not equal operator
if(nodes.size() == 3)
{
if (nodes.size() == 3) {
op_node->add_output("EqualOperator");
auto* not_node = graph.add_node();
not_node->add_input(op_node->output(0));
Expand Down

0 comments on commit 1aa9858

Please sign in to comment.