From 828bda060e1edf3eb05c73d8e25c5f0449158b2b Mon Sep 17 00:00:00 2001 From: iroul Date: Thu, 17 Sep 2015 04:16:01 -0700 Subject: [PATCH] sanity test passed --- include/mxnet/operator.h | 4 ++ include/mxnet/symbolic.h | 36 ++++++++++ src/operator/activation-inl.h | 6 +- src/operator/batch_norm-inl.h | 4 +- src/operator/concat-inl.h | 5 +- src/operator/convolution-inl.h | 4 +- src/operator/dropout-inl.h | 4 +- src/operator/elementwise_binary_op-inl.h | 2 +- src/operator/elementwise_sum-inl.h | 15 ++-- src/operator/fully_connected-inl.h | 5 +- src/operator/leaky_relu-inl.h | 5 +- src/operator/lrn-inl.h | 5 +- src/operator/operator_common.h | 40 +++++++++++ src/operator/param.h | 1 - src/operator/pooling-inl.h | 5 +- src/operator/reshape-inl.h | 7 +- src/operator/softmax-inl.h | 5 +- src/symbol/static_graph.cc | 87 ++++++++++++++++++++++++ 18 files changed, 191 insertions(+), 49 deletions(-) diff --git a/include/mxnet/operator.h b/include/mxnet/operator.h index a62d97425da0..3700a513a546 100644 --- a/include/mxnet/operator.h +++ b/include/mxnet/operator.h @@ -8,6 +8,7 @@ #define MXNET_OPERATOR_H_ #include +#include #include #include #include @@ -389,6 +390,9 @@ class OperatorProperty { * \return a new constructed OperatorProperty */ static OperatorProperty *Create(const char* type_name); + + virtual void Save(dmlc::JSONWriter *writer) const = 0; + virtual void Load(dmlc::JSONReader *reader) = 0; }; /*! \brief typedef the factory function of operator property */ diff --git a/include/mxnet/symbolic.h b/include/mxnet/symbolic.h index ad01040007a7..e8e942eda534 100644 --- a/include/mxnet/symbolic.h +++ b/include/mxnet/symbolic.h @@ -8,6 +8,8 @@ #define MXNET_SYMBOLIC_H_ #include +#include +#include #include #include #include @@ -64,6 +66,11 @@ class StaticGraph { if (source_id == other.source_id) return index < other.index; return source_id < other.source_id; } + + /*! \brief interface for json serialization */ + void Save(dmlc::JSONWriter *writer) const; + /*! \brief interface for json serialization */ + void Load(dmlc::JSONReader *reader); }; /*! * \brief Operation Node in static graphs. @@ -95,6 +102,23 @@ class StaticGraph { int32_t backward_source_id; /*! \brief default constructor */ Node() : backward_source_id(-1) {} + + friend void swap(Node& lhs, Node& rhs) { + std::swap(lhs.op, rhs.op); + std::swap(lhs.name, rhs.name); + std::swap(lhs.inputs, rhs.inputs); + std::swap(lhs.backward_source_id, rhs.backward_source_id); + } + /*! \brief copy constructor in favor of serialization. */ + Node(const Node& another) : op(another.op.get() ? another.op.get()->Copy() : nullptr), + name(another.name), + inputs(another.inputs), + backward_source_id(another.backward_source_id) {} + + inline Node& operator=(Node another) { + swap(*this, another); + return *this; + } /*! \return whether the node is forward op node */ inline bool is_forward() const { return op != nullptr; @@ -107,6 +131,10 @@ class StaticGraph { inline bool is_variable() const { return op == nullptr && !is_backward(); } + /*! \brief interface for json serialization */ + void Save(dmlc::JSONWriter *writer) const; + /*! \brief interface for json serialization */ + void Load(dmlc::JSONReader *reader); }; /*! \brief all nodes in the graph */ std::vector nodes; @@ -114,6 +142,14 @@ class StaticGraph { std::vector arg_nodes; /*! \brief heads outputs of the graph */ std::vector heads; + /*! \brief load static graph from json. TODO: a static creator's better */ + void Load(const std::string& json); + /*! \brief save static graph to json */ + void Save(std::string* json) const; + /*! \brief interface for json serialization */ + void Save(dmlc::JSONWriter *writer) const; + /*! \brief interface for json serialization */ + void Load(dmlc::JSONReader *reader); // funtions to help inference in static graph /*! * \brief Perform a topological sort on the graph diff --git a/src/operator/activation-inl.h b/src/operator/activation-inl.h index 98445b629b9e..9105c37fd1b2 100644 --- a/src/operator/activation-inl.h +++ b/src/operator/activation-inl.h @@ -84,7 +84,7 @@ template Operator* CreateOp(ActivationParam type); #if DMLC_USE_CXX11 -class ActivationProp : public OperatorProperty { +class ActivationProp : public ParamOperatorProperty { public: void Init(const std::vector >& kwargs) override { param_.Init(kwargs); @@ -139,12 +139,8 @@ class ActivationProp : public OperatorProperty { } Operator* CreateOperator(Context ctx) const; - - private: - ActivationParam param_; }; #endif // DMLC_USE_CXX11 } // namespace op } // namespace mxnet #endif // MXNET_OPERATOR_ACTIVATION_INL_H_ - diff --git a/src/operator/batch_norm-inl.h b/src/operator/batch_norm-inl.h index c6bca12bd0db..3827dbb909e6 100644 --- a/src/operator/batch_norm-inl.h +++ b/src/operator/batch_norm-inl.h @@ -182,7 +182,7 @@ Operator *CreateOp(BatchNormParam param); #if DMLC_USE_CXX11 -class BatchNormProp : public OperatorProperty { +class BatchNormProp : public ParamOperatorProperty { public: void Init(const std::vector >& kwargs) override { param_.Init(kwargs); @@ -263,8 +263,6 @@ class BatchNormProp : public OperatorProperty { Operator* CreateOperator(Context ctx) const; - private: - BatchNormParam param_; }; // class BatchNormProp #endif // DMLC_USE_CXX11 diff --git a/src/operator/concat-inl.h b/src/operator/concat-inl.h index dc18da329bf4..1a69a261d9bb 100644 --- a/src/operator/concat-inl.h +++ b/src/operator/concat-inl.h @@ -163,7 +163,7 @@ template Operator *CreateOp(ConcatParam param); #if DMLC_USE_CXX11 -class ConcatProp : public OperatorProperty { +class ConcatProp : public ParamOperatorProperty { public: void Init(const std::vector >& kwargs) override { param_.Init(kwargs); @@ -223,9 +223,6 @@ class ConcatProp : public OperatorProperty { } Operator* CreateOperator(Context ctx) const; - - private: - ConcatParam param_; }; // class ConcatProp #endif // DMLC_USE_CXX11 } // namespace op diff --git a/src/operator/convolution-inl.h b/src/operator/convolution-inl.h index f26265f3f3f1..ca7c4e609e73 100644 --- a/src/operator/convolution-inl.h +++ b/src/operator/convolution-inl.h @@ -262,7 +262,7 @@ template Operator* CreateOp(ConvolutionParam param); #if DMLC_USE_CXX11 -class ConvolutionProp : public OperatorProperty { +class ConvolutionProp : public ParamOperatorProperty { public: std::vector ListArguments() const override { if (!param_.no_bias) { @@ -358,8 +358,6 @@ class ConvolutionProp : public OperatorProperty { Operator* CreateOperator(Context ctx) const; - private: - ConvolutionParam param_; }; // class ConvolutionProp #endif // DMLC_USE_CXX11 } // namespace op diff --git a/src/operator/dropout-inl.h b/src/operator/dropout-inl.h index 512cf7f3d10d..6afdba146e03 100644 --- a/src/operator/dropout-inl.h +++ b/src/operator/dropout-inl.h @@ -91,7 +91,7 @@ template Operator *CreateOp(DropoutParam param); #if DMLC_USE_CXX11 -class DropoutProp : public OperatorProperty { +class DropoutProp : public ParamOperatorProperty { public: void Init(const std::vector >& kwargs) override { param_.Init(kwargs); @@ -160,8 +160,6 @@ class DropoutProp : public OperatorProperty { Operator* CreateOperator(Context ctx) const; - private: - DropoutParam param_; }; // class DropoutProp #endif // DMLC_USE_CXX11 } // namespace op diff --git a/src/operator/elementwise_binary_op-inl.h b/src/operator/elementwise_binary_op-inl.h index f8136af7b156..9cd6ef57be11 100644 --- a/src/operator/elementwise_binary_op-inl.h +++ b/src/operator/elementwise_binary_op-inl.h @@ -157,7 +157,7 @@ Operator* CreateElementWiseBinaryOp(ElementWiseBinaryOpType type); #if DMLC_USE_CXX11 template -class ElementWiseBinaryOpProp : public OperatorProperty { +class ElementWiseBinaryOpProp : public NoParamOperatorProperty { public: void Init(const std::vector >& kwargs) override { CHECK_EQ(kwargs.size(), 0) diff --git a/src/operator/elementwise_sum-inl.h b/src/operator/elementwise_sum-inl.h index 4e73d7e77efd..ebd31f155158 100644 --- a/src/operator/elementwise_sum-inl.h +++ b/src/operator/elementwise_sum-inl.h @@ -102,6 +102,16 @@ class ElementWiseSumOp : public Operator { Assign(igrad, req[i], F(ograd)); } } + inline void Save(dmlc::JSONWriter *writer) const { + writer->BeginObject(); + writer->WriteObjectKeyValue("size_", size_); + writer->EndObject(); + } + inline void Load(dmlc::JSONReader *reader) { + dmlc::JSONObjectReadHelper helper; + helper.DeclareField("size_", &size_); + helper.ReadAllFields(reader); + } private: int size_; @@ -111,7 +121,7 @@ template Operator* CreateOp(ElementWiseSumParam param); #if DMLC_USE_CXX11 -class ElementWiseSumProp : public OperatorProperty { +class ElementWiseSumProp : public ParamOperatorProperty { public: void Init(const std::vector >& kwargs) override { param_.Init(kwargs); @@ -180,9 +190,6 @@ class ElementWiseSumProp : public OperatorProperty { } Operator* CreateOperator(Context ctx) const; - - private: - ElementWiseSumParam param_; }; // class ElementWiseSumProp #endif // DMLC_USE_CXX11 diff --git a/src/operator/fully_connected-inl.h b/src/operator/fully_connected-inl.h index bde719d4ed4e..dfc718596103 100644 --- a/src/operator/fully_connected-inl.h +++ b/src/operator/fully_connected-inl.h @@ -124,7 +124,7 @@ template Operator* CreateOp(FullyConnectedParam param); #if DMLC_USE_CXX11 -class FullyConnectedProp : public OperatorProperty { +class FullyConnectedProp : public ParamOperatorProperty { public: std::vector ListArguments() const override { if (!param_.no_bias) { @@ -189,9 +189,6 @@ class FullyConnectedProp : public OperatorProperty { } Operator* CreateOperator(Context ctx) const; - - private: - FullyConnectedParam param_; }; // class FullyConnectedSymbol #endif } // namespace op diff --git a/src/operator/leaky_relu-inl.h b/src/operator/leaky_relu-inl.h index ba5a874213cf..5f4ed83990ea 100644 --- a/src/operator/leaky_relu-inl.h +++ b/src/operator/leaky_relu-inl.h @@ -190,7 +190,7 @@ template Operator* CreateOp(LeakyReLUParam type); #if DMLC_USE_CXX11 -class LeakyReLUProp : public OperatorProperty { +class LeakyReLUProp : public ParamOperatorProperty { public: void Init(const std::vector >& kwargs) override { param_.Init(kwargs); @@ -298,9 +298,6 @@ class LeakyReLUProp : public OperatorProperty { } Operator* CreateOperator(Context ctx) const; - - private: - LeakyReLUParam param_; }; #endif // DMLC_USE_CXX11 } // namespace op diff --git a/src/operator/lrn-inl.h b/src/operator/lrn-inl.h index 06476a4ce4ee..7b9326bd6892 100644 --- a/src/operator/lrn-inl.h +++ b/src/operator/lrn-inl.h @@ -98,7 +98,7 @@ template Operator *CreateOp(LRNParam param); #if DMLC_USE_CXX11 -class LocalResponseNormProp : public OperatorProperty { +class LocalResponseNormProp : public ParamOperatorProperty { public: void Init(const std::vector >& kwargs) override { param_.Init(kwargs); @@ -173,9 +173,6 @@ class LocalResponseNormProp : public OperatorProperty { } Operator* CreateOperator(Context ctx) const; - - private: - LRNParam param_; }; // LocalResponseNormProp #endif // DMLC_USE_CXX11 } // namespace op diff --git a/src/operator/operator_common.h b/src/operator/operator_common.h index 8c341bada778..299dbde0ed47 100644 --- a/src/operator/operator_common.h +++ b/src/operator/operator_common.h @@ -8,9 +8,12 @@ #ifndef MXNET_OPERATOR_OPERATOR_COMMON_H_ #define MXNET_OPERATOR_OPERATOR_COMMON_H_ +#include #include #include #include +#include +#include #include namespace mxnet { @@ -93,6 +96,43 @@ struct InferShapeError { } #endif +#if DMLC_USE_CXX11 +template +class ParamOperatorProperty : public OperatorProperty { + public: + ParamOperatorProperty() {} + explicit ParamOperatorProperty(Param param) : param_(param) {} + inline void Save(dmlc::JSONWriter *writer) const { + writer->BeginObject(); + std::string value = param_.PrintJson(); + writer->WriteObjectKeyValue("param", value); + writer->EndObject(); + } + inline void Load(dmlc::JSONReader *reader) { + dmlc::JSONObjectReadHelper helper; + std::string value; + helper.DeclareField("param", &value); + helper.ReadAllFields(reader); + param_.LoadJson(value); + } + inline bool operator==(const ParamOperatorProperty& other) const { + return param_ == other.param_; + } + protected: + Param param_; +}; + +class NoParamOperatorProperty : public OperatorProperty { + public: + inline void Save(dmlc::JSONWriter *writer) const { + } + inline void Load(dmlc::JSONReader *reader) { + } + inline bool operator==(const NoParamOperatorProperty& other) const { + return true; + } +}; +#endif // DMLC_USE_CXX11 } // namespace op } // namespace mxnet #endif // MXNET_OPERATOR_OPERATOR_COMMON_H_ diff --git a/src/operator/param.h b/src/operator/param.h index f0ce5886e2fb..9b08c197a160 100644 --- a/src/operator/param.h +++ b/src/operator/param.h @@ -71,4 +71,3 @@ struct Param { } // namespace mxnet #endif // MXNET_OPERATOR_PARAM_H_ - diff --git a/src/operator/pooling-inl.h b/src/operator/pooling-inl.h index c13e1f70b6a6..50aee978d34b 100644 --- a/src/operator/pooling-inl.h +++ b/src/operator/pooling-inl.h @@ -154,7 +154,7 @@ Operator* CreateOp(PoolingParam param); #if DMLC_USE_CXX11 -class PoolingProp : public OperatorProperty { +class PoolingProp : public ParamOperatorProperty { public: void Init(const std::vector >& kwargs) override { param_.Init(kwargs); @@ -209,9 +209,6 @@ class PoolingProp : public OperatorProperty { } Operator* CreateOperator(Context ctx) const; - - private: - PoolingParam param_; }; // class PoolingProp #endif // DMLC_USE_CXX11 } // namespace op diff --git a/src/operator/reshape-inl.h b/src/operator/reshape-inl.h index d992c21effcb..69a1e5e73143 100644 --- a/src/operator/reshape-inl.h +++ b/src/operator/reshape-inl.h @@ -83,11 +83,11 @@ template Operator* CreateOp(); #if DMLC_USE_CXX11 -class ReshapeProp : public OperatorProperty { +class ReshapeProp : public ParamOperatorProperty { public: ReshapeProp() {} - explicit ReshapeProp(ReshapeParam param) : param_(param) {} + explicit ReshapeProp(ReshapeParam param) : ParamOperatorProperty(param) {} void Init(const std::vector >& kwargs) override { param_.Init(kwargs); @@ -140,9 +140,6 @@ class ReshapeProp : public OperatorProperty { } Operator* CreateOperator(Context ctx) const; - - private: - ReshapeParam param_; }; // class ReshapeProp class FlattenProp : public ReshapeProp { diff --git a/src/operator/softmax-inl.h b/src/operator/softmax-inl.h index ea0114217cac..b90b2b7343c8 100644 --- a/src/operator/softmax-inl.h +++ b/src/operator/softmax-inl.h @@ -83,7 +83,7 @@ template Operator* CreateOp(SoftmaxParam param); #if DMLC_USE_CXX11 -class SoftmaxProp : public OperatorProperty { +class SoftmaxProp : public ParamOperatorProperty { public: std::vector ListArguments() const override { return {"data", "label"}; @@ -138,9 +138,6 @@ class SoftmaxProp : public OperatorProperty { } Operator* CreateOperator(Context ctx) const; - - private: - SoftmaxParam param_; }; // class SoftmaxProp #endif // DMLC_USE_CXX11 diff --git a/src/symbol/static_graph.cc b/src/symbol/static_graph.cc index 53df58dd96f1..4e58908b0fd2 100644 --- a/src/symbol/static_graph.cc +++ b/src/symbol/static_graph.cc @@ -291,4 +291,91 @@ void StaticGraph::MakeBackwardPass(std::vector *head_grad_nodes, } } } + +void StaticGraph::DataEntry::Save(dmlc::JSONWriter *writer) const { + writer->BeginObject(); + writer->WriteObjectKeyValue("source_id", source_id); + writer->WriteObjectKeyValue("index", index); + writer->EndObject(); +} + +void StaticGraph::DataEntry::Load(dmlc::JSONReader *reader) { + dmlc::JSONObjectReadHelper helper; + helper.DeclareField("source_id", &source_id); + helper.DeclareField("index", &index); + helper.ReadAllFields(reader); +} + +void StaticGraph::Node::Save(dmlc::JSONWriter *writer) const { + writer->BeginObject(); + if (op.get() != nullptr) { + writer->WriteObjectKeyValue("op_type", op.get()->TypeString()); + std::ostringstream os; + dmlc::JSONWriter subWriter(&os); + subWriter.BeginObject(); + subWriter.WriteObjectKeyValue("op", *(op.get())); + subWriter.EndObject(); + writer->WriteObjectKeyValue("op", os.str()); + } else { + std::string jsonNull = "null"; + writer->WriteObjectKeyValue("op_type", jsonNull); + writer->WriteObjectKeyValue("op", jsonNull); + } + writer->WriteObjectKeyValue("name", name); + writer->WriteObjectKeyValue("inputs", inputs); + writer->WriteObjectKeyValue("backward_source_id", backward_source_id); + writer->EndObject(); +} + +void StaticGraph::Node::Load(dmlc::JSONReader *reader) { + dmlc::JSONObjectReadHelper firstHelper; + std::string op_type_str; + firstHelper.DeclareField("op_type", &op_type_str); + std::string op_str; + firstHelper.DeclareField("op", &op_str); + firstHelper.DeclareField("name", &name); + firstHelper.DeclareField("inputs", &inputs); + firstHelper.DeclareField("backward_source_id", &backward_source_id); + firstHelper.ReadAllFields(reader); + if (op_type_str != "null") { + dmlc::JSONObjectReadHelper secondHelper; + std::istringstream iss(op_str); + dmlc::JSONReader subReader(&iss); + op.reset(OperatorProperty::Create(op_type_str.c_str())); + secondHelper.DeclareField("op", op.get()); + secondHelper.ReadAllFields(reader); + } else { + op.reset(nullptr); + } +} + +void StaticGraph::Save(dmlc::JSONWriter *writer) const { + writer->BeginObject(); + writer->WriteObjectKeyValue("nodes", nodes); + writer->WriteObjectKeyValue("arg_nodes", arg_nodes); + writer->WriteObjectKeyValue("heads", heads); + writer->EndObject(); +} + +void StaticGraph::Load(dmlc::JSONReader *reader) { + dmlc::JSONObjectReadHelper helper; + helper.DeclareField("nodes", &nodes); + helper.DeclareField("arg_nodes", &arg_nodes); + helper.DeclareField("heads", &heads); + helper.ReadAllFields(reader); +} + +void StaticGraph::Load(const std::string& json) { + std::istringstream is(json); + dmlc::JSONReader reader(&is); + reader.Read(this); +} + +void StaticGraph::Save(std::string* json) const { + std::ostringstream os; + dmlc::JSONWriter writer(&os); + writer.Write(*this); + *json = os.str(); +} + } // namespace mxnet