Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Merge branch 'Iroul-patch-0917'
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen committed Sep 17, 2015
2 parents a9d5227 + fabac55 commit b74d1f2
Show file tree
Hide file tree
Showing 18 changed files with 191 additions and 49 deletions.
4 changes: 4 additions & 0 deletions include/mxnet/operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#define MXNET_OPERATOR_H_

#include <dmlc/base.h>
#include <dmlc/json.h>
#include <dmlc/logging.h>
#include <dmlc/registry.h>
#include <vector>
Expand Down Expand Up @@ -389,6 +390,9 @@ class OperatorProperty {
* \return a new constructed OperatorProperty
*/
static OperatorProperty *Create(const char* type_name);

virtual void Save(dmlc::JSONWriter *writer) const = 0;
virtual void Load(dmlc::JSONReader *reader) = 0;
};

/*! \brief typedef the factory function of operator property */
Expand Down
36 changes: 36 additions & 0 deletions include/mxnet/symbolic.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
#define MXNET_SYMBOLIC_H_

#include <dmlc/base.h>
#include <dmlc/json.h>
#include <algorithm>
#include <vector>
#include <memory>
#include <string>
Expand Down Expand Up @@ -64,6 +66,11 @@ class StaticGraph {
if (source_id == other.source_id) return index < other.index;
return source_id < other.source_id;
}

/*! \brief interface for json serialization */
void Save(dmlc::JSONWriter *writer) const;
/*! \brief interface for json serialization */
void Load(dmlc::JSONReader *reader);
};
/*!
* \brief Operation Node in static graphs.
Expand Down Expand Up @@ -95,6 +102,23 @@ class StaticGraph {
int32_t backward_source_id;
/*! \brief default constructor */
Node() : backward_source_id(-1) {}

friend void swap(Node& lhs, Node& rhs) {
std::swap(lhs.op, rhs.op);
std::swap(lhs.name, rhs.name);
std::swap(lhs.inputs, rhs.inputs);
std::swap(lhs.backward_source_id, rhs.backward_source_id);
}
/*! \brief copy constructor in favor of serialization. */
Node(const Node& another) : op(another.op.get() ? another.op.get()->Copy() : nullptr),
name(another.name),
inputs(another.inputs),
backward_source_id(another.backward_source_id) {}

inline Node& operator=(Node another) {
swap(*this, another);
return *this;
}
/*! \return whether the node is forward op node */
inline bool is_forward() const {
return op != nullptr;
Expand All @@ -107,13 +131,25 @@ class StaticGraph {
inline bool is_variable() const {
return op == nullptr && !is_backward();
}
/*! \brief interface for json serialization */
void Save(dmlc::JSONWriter *writer) const;
/*! \brief interface for json serialization */
void Load(dmlc::JSONReader *reader);
};
/*! \brief all nodes in the graph */
std::vector<Node> nodes;
/*! \brief index of nodes that correspods to arguments */
std::vector<uint32_t> arg_nodes;
/*! \brief heads outputs of the graph */
std::vector<DataEntry> heads;
/*! \brief load static graph from json. TODO: a static creator's better */
void Load(const std::string& json);
/*! \brief save static graph to json */
void Save(std::string* json) const;
/*! \brief interface for json serialization */
void Save(dmlc::JSONWriter *writer) const;
/*! \brief interface for json serialization */
void Load(dmlc::JSONReader *reader);
// funtions to help inference in static graph
/*!
* \brief Perform a topological sort on the graph
Expand Down
6 changes: 1 addition & 5 deletions src/operator/activation-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ template<typename xpu>
Operator* CreateOp(ActivationParam type);

#if DMLC_USE_CXX11
class ActivationProp : public OperatorProperty {
class ActivationProp : public ParamOperatorProperty<ActivationParam> {
public:
void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) override {
param_.Init(kwargs);
Expand Down Expand Up @@ -139,12 +139,8 @@ class ActivationProp : public OperatorProperty {
}

Operator* CreateOperator(Context ctx) const;

private:
ActivationParam param_;
};
#endif // DMLC_USE_CXX11
} // namespace op
} // namespace mxnet
#endif // MXNET_OPERATOR_ACTIVATION_INL_H_

4 changes: 1 addition & 3 deletions src/operator/batch_norm-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ Operator *CreateOp(BatchNormParam param);


#if DMLC_USE_CXX11
class BatchNormProp : public OperatorProperty {
class BatchNormProp : public ParamOperatorProperty<BatchNormParam> {
public:
void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) override {
param_.Init(kwargs);
Expand Down Expand Up @@ -263,8 +263,6 @@ class BatchNormProp : public OperatorProperty {

Operator* CreateOperator(Context ctx) const;

private:
BatchNormParam param_;
}; // class BatchNormProp

#endif // DMLC_USE_CXX11
Expand Down
5 changes: 1 addition & 4 deletions src/operator/concat-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ template<typename xpu>
Operator *CreateOp(ConcatParam param);

#if DMLC_USE_CXX11
class ConcatProp : public OperatorProperty {
class ConcatProp : public ParamOperatorProperty<ConcatParam> {
public:
void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) override {
param_.Init(kwargs);
Expand Down Expand Up @@ -223,9 +223,6 @@ class ConcatProp : public OperatorProperty {
}

Operator* CreateOperator(Context ctx) const;

private:
ConcatParam param_;
}; // class ConcatProp
#endif // DMLC_USE_CXX11
} // namespace op
Expand Down
4 changes: 1 addition & 3 deletions src/operator/convolution-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ template<typename xpu>
Operator* CreateOp(ConvolutionParam param);

#if DMLC_USE_CXX11
class ConvolutionProp : public OperatorProperty {
class ConvolutionProp : public ParamOperatorProperty<ConvolutionParam> {
public:
std::vector<std::string> ListArguments() const override {
if (!param_.no_bias) {
Expand Down Expand Up @@ -358,8 +358,6 @@ class ConvolutionProp : public OperatorProperty {

Operator* CreateOperator(Context ctx) const;

private:
ConvolutionParam param_;
}; // class ConvolutionProp
#endif // DMLC_USE_CXX11
} // namespace op
Expand Down
4 changes: 1 addition & 3 deletions src/operator/dropout-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ template<typename xpu>
Operator *CreateOp(DropoutParam param);

#if DMLC_USE_CXX11
class DropoutProp : public OperatorProperty {
class DropoutProp : public ParamOperatorProperty<DropoutParam> {
public:
void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) override {
param_.Init(kwargs);
Expand Down Expand Up @@ -160,8 +160,6 @@ class DropoutProp : public OperatorProperty {

Operator* CreateOperator(Context ctx) const;

private:
DropoutParam param_;
}; // class DropoutProp
#endif // DMLC_USE_CXX11
} // namespace op
Expand Down
2 changes: 1 addition & 1 deletion src/operator/elementwise_binary_op-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ Operator* CreateElementWiseBinaryOp(ElementWiseBinaryOpType type);

#if DMLC_USE_CXX11
template<typename ForwardOp>
class ElementWiseBinaryOpProp : public OperatorProperty {
class ElementWiseBinaryOpProp : public NoParamOperatorProperty {
public:
void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) override {
CHECK_EQ(kwargs.size(), 0)
Expand Down
15 changes: 11 additions & 4 deletions src/operator/elementwise_sum-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,16 @@ class ElementWiseSumOp : public Operator {
Assign(igrad, req[i], F<mshadow_op::identity>(ograd));
}
}
inline void Save(dmlc::JSONWriter *writer) const {
writer->BeginObject();
writer->WriteObjectKeyValue("size_", size_);
writer->EndObject();
}
inline void Load(dmlc::JSONReader *reader) {
dmlc::JSONObjectReadHelper helper;
helper.DeclareField("size_", &size_);
helper.ReadAllFields(reader);
}

private:
int size_;
Expand All @@ -111,7 +121,7 @@ template<typename xpu>
Operator* CreateOp(ElementWiseSumParam param);

#if DMLC_USE_CXX11
class ElementWiseSumProp : public OperatorProperty {
class ElementWiseSumProp : public ParamOperatorProperty<ElementWiseSumParam> {
public:
void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) override {
param_.Init(kwargs);
Expand Down Expand Up @@ -180,9 +190,6 @@ class ElementWiseSumProp : public OperatorProperty {
}

Operator* CreateOperator(Context ctx) const;

private:
ElementWiseSumParam param_;
}; // class ElementWiseSumProp

#endif // DMLC_USE_CXX11
Expand Down
5 changes: 1 addition & 4 deletions src/operator/fully_connected-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ template<typename xpu>
Operator* CreateOp(FullyConnectedParam param);

#if DMLC_USE_CXX11
class FullyConnectedProp : public OperatorProperty {
class FullyConnectedProp : public ParamOperatorProperty<FullyConnectedParam> {
public:
std::vector<std::string> ListArguments() const override {
if (!param_.no_bias) {
Expand Down Expand Up @@ -189,9 +189,6 @@ class FullyConnectedProp : public OperatorProperty {
}

Operator* CreateOperator(Context ctx) const;

private:
FullyConnectedParam param_;
}; // class FullyConnectedSymbol
#endif
} // namespace op
Expand Down
5 changes: 1 addition & 4 deletions src/operator/leaky_relu-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ template<typename xpu>
Operator* CreateOp(LeakyReLUParam type);

#if DMLC_USE_CXX11
class LeakyReLUProp : public OperatorProperty {
class LeakyReLUProp : public ParamOperatorProperty<LeakyReLUParam> {
public:
void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) override {
param_.Init(kwargs);
Expand Down Expand Up @@ -298,9 +298,6 @@ class LeakyReLUProp : public OperatorProperty {
}

Operator* CreateOperator(Context ctx) const;

private:
LeakyReLUParam param_;
};
#endif // DMLC_USE_CXX11
} // namespace op
Expand Down
5 changes: 1 addition & 4 deletions src/operator/lrn-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ template<typename xpu>
Operator *CreateOp(LRNParam param);

#if DMLC_USE_CXX11
class LocalResponseNormProp : public OperatorProperty {
class LocalResponseNormProp : public ParamOperatorProperty<LRNParam> {
public:
void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) override {
param_.Init(kwargs);
Expand Down Expand Up @@ -173,9 +173,6 @@ class LocalResponseNormProp : public OperatorProperty {
}

Operator* CreateOperator(Context ctx) const;

private:
LRNParam param_;
}; // LocalResponseNormProp
#endif // DMLC_USE_CXX11
} // namespace op
Expand Down
40 changes: 40 additions & 0 deletions src/operator/operator_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,12 @@
#ifndef MXNET_OPERATOR_OPERATOR_COMMON_H_
#define MXNET_OPERATOR_OPERATOR_COMMON_H_

#include <dmlc/json.h>
#include <dmlc/logging.h>
#include <mxnet/operator.h>
#include <mxnet/base.h>
#include <istream>
#include <ostream>
#include <string>

namespace mxnet {
Expand Down Expand Up @@ -93,6 +96,43 @@ struct InferShapeError {
}
#endif

#if DMLC_USE_CXX11
template<class Param>
class ParamOperatorProperty : public OperatorProperty {
public:
ParamOperatorProperty() {}
explicit ParamOperatorProperty(Param param) : param_(param) {}
inline void Save(dmlc::JSONWriter *writer) const {
writer->BeginObject();
std::string value = param_.PrintJson();
writer->WriteObjectKeyValue("param", value);
writer->EndObject();
}
inline void Load(dmlc::JSONReader *reader) {
dmlc::JSONObjectReadHelper helper;
std::string value;
helper.DeclareField("param", &value);
helper.ReadAllFields(reader);
param_.LoadJson(value);
}
inline bool operator==(const ParamOperatorProperty<Param>& other) const {
return param_ == other.param_;
}
protected:
Param param_;
};

class NoParamOperatorProperty : public OperatorProperty {
public:
inline void Save(dmlc::JSONWriter *writer) const {
}
inline void Load(dmlc::JSONReader *reader) {
}
inline bool operator==(const NoParamOperatorProperty& other) const {
return true;
}
};
#endif // DMLC_USE_CXX11
} // namespace op
} // namespace mxnet
#endif // MXNET_OPERATOR_OPERATOR_COMMON_H_
1 change: 0 additions & 1 deletion src/operator/param.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,4 +71,3 @@ struct Param {
} // namespace mxnet

#endif // MXNET_OPERATOR_PARAM_H_

5 changes: 1 addition & 4 deletions src/operator/pooling-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ Operator* CreateOp(PoolingParam param);


#if DMLC_USE_CXX11
class PoolingProp : public OperatorProperty {
class PoolingProp : public ParamOperatorProperty<PoolingParam> {
public:
void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) override {
param_.Init(kwargs);
Expand Down Expand Up @@ -209,9 +209,6 @@ class PoolingProp : public OperatorProperty {
}

Operator* CreateOperator(Context ctx) const;

private:
PoolingParam param_;
}; // class PoolingProp
#endif // DMLC_USE_CXX11
} // namespace op
Expand Down
7 changes: 2 additions & 5 deletions src/operator/reshape-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,11 +83,11 @@ template<typename xpu>
Operator* CreateOp();

#if DMLC_USE_CXX11
class ReshapeProp : public OperatorProperty {
class ReshapeProp : public ParamOperatorProperty<ReshapeParam> {
public:
ReshapeProp() {}

explicit ReshapeProp(ReshapeParam param) : param_(param) {}
explicit ReshapeProp(ReshapeParam param) : ParamOperatorProperty<ReshapeParam>(param) {}

void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) override {
param_.Init(kwargs);
Expand Down Expand Up @@ -140,9 +140,6 @@ class ReshapeProp : public OperatorProperty {
}

Operator* CreateOperator(Context ctx) const;

private:
ReshapeParam param_;
}; // class ReshapeProp

class FlattenProp : public ReshapeProp {
Expand Down
Loading

0 comments on commit b74d1f2

Please sign in to comment.