diff --git a/dmlc-core b/dmlc-core index 2e2d187efc43..d2003b69edc6 160000 --- a/dmlc-core +++ b/dmlc-core @@ -1 +1 @@ -Subproject commit 2e2d187efc43ee2df1d132c3690169575e830442 +Subproject commit d2003b69edc653698d01f3c33986d241b822b3be diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index c5634cddd942..7441db84638f 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -123,10 +123,10 @@ MXNET_DLL int MXNDArraySaveRawBytes(NDArrayHandle handle, * \param keys the name of the NDArray, optional, can be NULL * \return 0 when success, -1 when failure happens */ -MXNET_DLL int MXNDArrayListSave(const char* fname, - mx_uint num_args, - NDArrayHandle* args, - const char** keys); +MXNET_DLL int MXNDArraySave(const char* fname, + mx_uint num_args, + NDArrayHandle* args, + const char** keys); /*! * \brief Load list of narray from the file. * \param fname name of the file. @@ -136,11 +136,11 @@ MXNET_DLL int MXNDArrayListSave(const char* fname, * \param out_names the names of returning NDArrays, can be NULL * \return 0 when success, -1 when failure happens */ -MXNET_DLL int MXNDArrayListLoad(const char* fname, - mx_uint *out_size, - NDArrayHandle** out_arr, - mx_uint *out_name_size, - const char*** out_names); +MXNET_DLL int MXNDArrayLoad(const char* fname, + mx_uint *out_size, + NDArrayHandle** out_arr, + mx_uint *out_name_size, + const char*** out_names); /*! * \brief Perform a synchronize copy from a continugous CPU memory region. * @@ -359,13 +359,33 @@ MXNET_DLL int MXSymbolCreateGroup(mx_uint num_symbols, SymbolHandle *symbols, SymbolHandle *out); /*! - * \brief Create symbol from config. - * \param cfg configuration string - * \param out created symbol handle + * \brief Load a symbol from a json file. + * \param fname the file name. + * \param out the output symbol. + * \return 0 when success, -1 when failure happens + */ +MXNET_DLL int MXSymbolCreateFromFile(const char *fname, SymbolHandle *out); +/*! + * \brief Load a symbol from a json string. + * \param json the json string. + * \param out the output symbol. + * \return 0 when success, -1 when failure happens + */ +MXNET_DLL int MXSymbolCreateFromJSON(const char *json, SymbolHandle *out); +/*! + * \brief Save a symbol into a json file. + * \param sym the input symbol. + * \param fname the file name. + * \return 0 when success, -1 when failure happens + */ +MXNET_DLL int MXSymbolSaveToFile(SymbolHandle symbol, const char *fname); +/*! + * \brief Save a symbol into a json string + * \param sym the input symbol. + * \param out_json output json string. * \return 0 when success, -1 when failure happens */ -MXNET_DLL int MXSymbolCreateFromConfig(const char *cfg, - SymbolHandle *out); +MXNET_DLL int MXSymbolSaveToJSON(SymbolHandle symbol, const char **out_json); /*! * \brief Free the symbol handle. * \param symbol the symbol diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h index 047f9723916f..26f9306ddc97 100644 --- a/include/mxnet/ndarray.h +++ b/include/mxnet/ndarray.h @@ -11,6 +11,8 @@ #include #include #include +#include +#include #include #include "./base.h" #include "./storage.h" @@ -244,6 +246,24 @@ class NDArray { inline void CheckAndAlloc() const { ptr_->CheckAndAlloc(); } + /*! + * \brief Save list of narray into the file. + * \param fname name of the file. + * \param data the NDArrays to be saved. + * \param keys the name of the NDArray, optional, can be zero length. + */ + static void Save(const std::string& fname, + const std::vector& data, + const std::vector& names); + /*! + * \brief Load list of narray into from the file. + * \param fname name of the file. + * \param data the NDArrays to be loaded + * \param keys the name of the NDArray, if saved in the file. + */ + static void Load(const std::string& fname, + std::vector* data, + std::vector* keys); private: /*! \brief the real data chunk that backs NDArray */ @@ -397,7 +417,6 @@ void SampleUniform(real_t begin, real_t end, NDArray *out); * \param out output NDArray. */ void SampleGaussian(real_t mu, real_t sigma, NDArray *out); - //-------------------------------------------------------------- // The following part are API Registration of NDArray functions. //-------------------------------------------------------------- diff --git a/include/mxnet/operator.h b/include/mxnet/operator.h index 3700a513a546..72c5f6c28823 100644 --- a/include/mxnet/operator.h +++ b/include/mxnet/operator.h @@ -12,6 +12,7 @@ #include #include #include +#include #include #include #include "./base.h" @@ -150,6 +151,11 @@ class OperatorProperty { * \param kwargs the keyword arguments parameters */ virtual void Init(const std::vector >& kwargs) = 0; + /*! + * \brief Get a map representation of internal parameters. + * This can be used by Init to recover the state of OperatorProperty. + */ + virtual std::map GetParams() const = 0; /*! * \brief Get input arguments of the Operator. * \return vector of arguments. @@ -222,6 +228,7 @@ class OperatorProperty { /*! * \brief return the type string of the Operator * subclasses override this function. + * \return The type string. */ virtual std::string TypeString() const = 0; //-------------------------------------------------------- @@ -390,9 +397,6 @@ class OperatorProperty { * \return a new constructed OperatorProperty */ static OperatorProperty *Create(const char* type_name); - - virtual void Save(dmlc::JSONWriter *writer) const = 0; - virtual void Load(dmlc::JSONReader *reader) = 0; }; /*! \brief typedef the factory function of operator property */ @@ -419,6 +423,19 @@ struct OperatorPropertyReg this->key_var_num_args = key; return *this; } + /*! + * \brief Check if TypeString of the type matches the registered name + */ + inline OperatorPropertyReg& check_name() { + OperatorProperty *p = this->body(); + std::string type = p->TypeString(); + delete p; + CHECK_EQ(this->name, type) + << "Register Name and TypeString mismatch, name=\"" << this->name << "\"," + << " but TypeString=\"" << type <<"\""; + return *this; + } + /*! \brief The key num_args name. */ std::string key_var_num_args; }; @@ -438,10 +455,12 @@ struct OperatorPropertyReg */ #define MXNET_REGISTER_OP_PROPERTY(name, OperatorPropertyType) \ static ::mxnet::OperatorProperty* __create__ ## OperatorProperty ## name ## __() { \ - return new OperatorPropertyType; \ + OperatorProperty* ret = new OperatorPropertyType(); \ + return ret; \ } \ DMLC_REGISTRY_REGISTER(::mxnet::OperatorPropertyReg, OperatorPropertyReg, name) \ - .set_body(__create__ ## OperatorProperty ## name ## __) + .set_body(__create__ ## OperatorProperty ## name ## __) \ + .check_name() #endif // DMLC_USE_CXX11 } // namespace mxnet diff --git a/include/mxnet/symbolic.h b/include/mxnet/symbolic.h index e8e942eda534..3a8d5c0b2ca9 100644 --- a/include/mxnet/symbolic.h +++ b/include/mxnet/symbolic.h @@ -12,6 +12,7 @@ #include #include #include +#include #include #include #include @@ -66,11 +67,25 @@ class StaticGraph { if (source_id == other.source_id) return index < other.index; return source_id < other.source_id; } - - /*! \brief interface for json serialization */ - void Save(dmlc::JSONWriter *writer) const; - /*! \brief interface for json serialization */ - void Load(dmlc::JSONReader *reader); + /*! + * \brief interface for json serialization. + * \param writer the JSON writer to write json into. + */ + inline void Save(dmlc::JSONWriter *writer) const { + writer->BeginArray(false); + writer->WriteArrayItem(source_id); + writer->WriteArrayItem(index); + writer->EndArray(); + } + /*! + * \brief interface for json serialization. + * \param reader the JSON reader to read json from. + */ + inline void Load(dmlc::JSONReader *reader) { + std::pair p; + reader->Read(&p); + *this = DataEntry(p.first, p.second); + } }; /*! * \brief Operation Node in static graphs. @@ -131,9 +146,15 @@ class StaticGraph { inline bool is_variable() const { return op == nullptr && !is_backward(); } - /*! \brief interface for json serialization */ + /*! + * \brief interface for json serialization. + * \param writer the JSON writer write json. + */ void Save(dmlc::JSONWriter *writer) const; - /*! \brief interface for json serialization */ + /*! + * \brief interface for json serialization. + * \param reader the JSON read to read json. + */ void Load(dmlc::JSONReader *reader); }; /*! \brief all nodes in the graph */ @@ -142,13 +163,15 @@ class StaticGraph { std::vector arg_nodes; /*! \brief heads outputs of the graph */ std::vector heads; - /*! \brief load static graph from json. TODO: a static creator's better */ - void Load(const std::string& json); - /*! \brief save static graph to json */ - void Save(std::string* json) const; - /*! \brief interface for json serialization */ + /*! + * \brief interface for json serialization. + * \param writer the JSON writer write json. + */ void Save(dmlc::JSONWriter *writer) const; - /*! \brief interface for json serialization */ + /*! + * \brief interface for json serialization. + * \param reader the JSON read to read json. + */ void Load(dmlc::JSONReader *reader); // funtions to help inference in static graph /*! @@ -282,6 +305,12 @@ class Symbol { * \param out_graph the pointer holder of the output graph */ void ToStaticGraph(StaticGraph *out_graph) const; + /*! + * \brief create equivalence of symbol from static graphs. + * This operation will change the content of current symbol. + * \param graph the static graph + */ + void FromStaticGraph(const StaticGraph &graph); /*! * \brief Apply the symbol as a function, compose with arguments * \param args positional arguments for the symbol @@ -303,7 +332,6 @@ class Symbol { * \return the new symbol with gradient graph */ Symbol Grad(const std::vector& wrt) const; - /*! * \brief infer the shapes of outputs and unknown input arguments * \param arg_shapes the shape of input arguments of the operator @@ -335,6 +363,24 @@ class Symbol { std::vector *arg_shapes, std::vector *out_shapes, std::vector *aux_shapes) const; + /*! + * \brief interface for json serialization. + * \param writer the JSON writer write json. + */ + inline void Save(dmlc::JSONWriter *writer) const { + StaticGraph g; + this->ToStaticGraph(&g); + g.Save(writer); + } + /*! + * \brief interface for json serialization. + * \param reader the JSON read to read json. + */ + inline void Load(dmlc::JSONReader *reader) { + StaticGraph g; + g.Load(reader); + this->FromStaticGraph(g); + } /*! * \brief get number of outputs of this symbol * \return number of outputs @@ -351,12 +397,6 @@ class Symbol { * \sa OperatorProperty::Create */ static Symbol Create(OperatorProperty *op); - /*! - * \brief create equivalence of symbol from static graphs - * \param graph the static graph - * \return the created symbol - */ - static Symbol Create(const StaticGraph &graph); /*! * \brief create equivalence of symbol by grouping the symbols together * \param symbols list of symbols @@ -466,4 +506,8 @@ class Executor { const std::vector &aux_states); }; // class operator } // namespace mxnet + +namespace dmlc { +DMLC_DECLARE_TRAITS(is_pod, ::mxnet::StaticGraph::DataEntry, true); +} #endif // MXNET_SYMBOLIC_H_ diff --git a/python/mxnet/ndarray.py b/python/mxnet/ndarray.py index a9634292f9c6..e5d1b5a903a2 100644 --- a/python/mxnet/ndarray.py +++ b/python/mxnet/ndarray.py @@ -405,11 +405,16 @@ def load(fname): You can also use pickle to do the job if you only work on python. The advantage of load/save is the file is language agnostic. This means the file saved using save can be loaded by other language binding of mxnet. + You also get the benefit being able to directly load/save from cloud storage(S3, HDFS) Parameters ---------- fname : str - The name of the file + The name of the file.Can be S3 or HDFS address (remember built with S3 support). + Example of fname: + - s3://my-bucket/path/my-s3-ndarray + - hdfs://my-bucket/path/my-hdfs-ndarray + - /path-to/my-local-ndarray Returns ------- @@ -422,11 +427,11 @@ def load(fname): out_name_size = mx_uint() handles = ctypes.POINTER(NDArrayHandle)() names = ctypes.POINTER(ctypes.c_char_p)() - check_call(_LIB.MXNDArrayListLoad(c_str(fname), - ctypes.byref(out_size), - ctypes.byref(handles), - ctypes.byref(out_name_size), - ctypes.byref(names))) + check_call(_LIB.MXNDArrayLoad(c_str(fname), + ctypes.byref(out_size), + ctypes.byref(handles), + ctypes.byref(out_name_size), + ctypes.byref(names))) if out_name_size.value == 0: return [NDArray(NDArrayHandle(handles[i])) for i in range(out_size.value)] else: @@ -441,11 +446,16 @@ def save(fname, data): You can also use pickle to do the job if you only work on python. The advantage of load/save is the file is language agnostic. This means the file saved using save can be loaded by other language binding of mxnet. + You also get the benefit being able to directly load/save from cloud storage(S3, HDFS) Parameters ---------- fname : str - The name of the file + The name of the file.Can be S3 or HDFS address (remember built with S3 support). + Example of fname: + - s3://my-bucket/path/my-s3-ndarray + - hdfs://my-bucket/path/my-hdfs-ndarray + - /path-to/my-local-ndarray data : list of NDArray or dict of str to NDArray The data to be saved. @@ -467,10 +477,10 @@ def save(fname, data): raise TypeError('save only accept dict str->NDArray or list of NDArray') handles.append(val.handle) keys = None - check_call(_LIB.MXNDArrayListSave(c_str(fname), - len(handles), - c_array(NDArrayHandle, handles), - keys)) + check_call(_LIB.MXNDArraySave(c_str(fname), + len(handles), + c_array(NDArrayHandle, handles), + keys)) # pylint: disable=too-many-locals, invalid-name diff --git a/python/mxnet/symbol.py b/python/mxnet/symbol.py index 179c4d6ca214..b5f93a0999d1 100644 --- a/python/mxnet/symbol.py +++ b/python/mxnet/symbol.py @@ -73,6 +73,22 @@ def __deepcopy__(self): ctypes.byref(handle))) return Symbol(handle) + def __getstate__(self): + this = self.__dict__.copy() + handle = this['handle'] + if handle is not None: + this['handle'] = self.tojson() + return this + + def __setstate__(self, state): + handle = state['handle'] + if handle is not None: + json_str = handle + handle = SymbolHandle() + check_call(_LIB.MXSymbolCreateFromJSON(c_str(json_str), ctypes.byref(handle))) + state['handle'] = handle + self.__dict__.update(state) + def __call__(self, *args, **kwargs): """Invoke symbol as function on inputs. @@ -280,6 +296,42 @@ def debug_str(self): self.handle, ctypes.byref(debug_str))) return py_str(debug_str.value) + + def save(self, fname): + """Save symbol into file. + + You can also use pickle to do the job if you only work on python. + The advantage of load/save is the file is language agnostic. + This means the file saved using save can be loaded by other language binding of mxnet. + You also get the benefit being able to directly load/save from cloud storage(S3, HDFS) + + Parameters + ---------- + fname : str + The name of the file + - s3://my-bucket/path/my-s3-symbol + - hdfs://my-bucket/path/my-hdfs-symbol + - /path-to/my-local-symbol + + See Also + -------- + symbol.load : Used to load symbol from file. + """ + if not isinstance(fname, string_types): + raise TypeError('fname need to be string') + check_call(_LIB.MXSymbolSaveToFile(self.handle, c_str(fname))) + + def tojson(self): + """Save symbol into a JSON string. + + See Also + -------- + symbol.load_json : Used to load symbol from JSON string. + """ + json_str = ctypes.c_char_p() + check_call(_LIB.MXSymbolSaveToJSON(self.handle, ctypes.byref(json_str))) + return py_str(json_str.value) + @staticmethod def _get_ndarray_handle(arg_key, args, arg_names, allow_missing): """Helper function to get ndarray handles from various inputs. @@ -550,6 +602,62 @@ def Group(symbols): return Symbol(handle) +def load(fname): + """Load symbol from a JSON file. + + You can also use pickle to do the job if you only work on python. + The advantage of load/save is the file is language agnostic. + This means the file saved using save can be loaded by other language binding of mxnet. + You also get the benefit being able to directly load/save from cloud storage(S3, HDFS) + + Parameters + ---------- + fname : str + The name of the file + - s3://my-bucket/path/my-s3-symbol + - hdfs://my-bucket/path/my-hdfs-symbol + - /path-to/my-local-symbol + + Returns + ------- + sym : Symbol + The loaded symbol. + + See Also + -------- + Symbol.save : Used to save symbol into file. + """ + if not isinstance(fname, string_types): + raise TypeError('fname need to be string') + handle = SymbolHandle() + check_call(_LIB.MXSymbolCreateFromFile(c_str(fname), ctypes.byref(handle))) + return Symbol(handle) + + +def load_json(json_str): + """Load symbol from json string. + + Parameters + ---------- + json_str : str + A json string. + + Returns + ------- + sym : Symbol + The loaded symbol. + + See Also + -------- + Symbol.tojson : Used to save symbol into json string. + """ + if not isinstance(json_str, string_types): + raise TypeError('fname need to be string') + handle = SymbolHandle() + check_call(_LIB.MXSymbolCreateFromJSON(c_str(json_str), ctypes.byref(handle))) + return Symbol(handle) + + def _make_atomic_symbol_function(handle): """Create an atomic symbol function by handle and funciton name.""" name = ctypes.c_char_p() diff --git a/src/c_api.cc b/src/c_api.cc index f60ae272e076..f1622c905d5b 100644 --- a/src/c_api.cc +++ b/src/c_api.cc @@ -267,12 +267,11 @@ int MXNDArrayWaitToWrite(NDArrayHandle handle) { API_END(); } -const uint64_t kMXAPINDArrayListMagic = 0x112; -int MXNDArrayListSave(const char* fname, - mx_uint num_args, - NDArrayHandle* args, - const char** keys) { +int MXNDArraySave(const char* fname, + mx_uint num_args, + NDArrayHandle* args, + const char** keys) { API_BEGIN(); std::vector data(num_args); std::vector names; @@ -285,39 +284,21 @@ int MXNDArrayListSave(const char* fname, names[i] = keys[i]; } } - std::unique_ptr fo(dmlc::Stream::Create(fname, "w")); - uint64_t header = kMXAPINDArrayListMagic, reserved = 0; - fo->Write(&header, sizeof(header)); - fo->Write(&reserved, sizeof(reserved)); - fo->Write(data); - fo->Write(names); + mxnet::NDArray::Save(fname, data, names); API_END(); } -int MXNDArrayListLoad(const char* fname, - mx_uint *out_size, - NDArrayHandle** out_arr, - mx_uint *out_name_size, - const char*** out_names) { +int MXNDArrayLoad(const char* fname, + mx_uint *out_size, + NDArrayHandle** out_arr, + mx_uint *out_name_size, + const char*** out_names) { MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); ret->ret_vec_str.clear(); API_BEGIN(); std::vector data; std::vector &names = ret->ret_vec_str; - std::unique_ptr fi(dmlc::Stream::Create(fname, "r")); - uint64_t header, reserved; - CHECK(fi->Read(&header)) - << "Invalid NDArray file format"; - CHECK(fi->Read(&reserved)) - << "Invalid NDArray file format"; - CHECK(header == kMXAPINDArrayListMagic) - << "Invalid NDArray file format"; - CHECK(fi->Read(&data)) - << "Invalid NDArray file format"; - CHECK(fi->Read(&names)) - << "Invalid NDArray file format"; - CHECK(names.size() == 0 || names.size() == data.size()) - << "Invalid NDArray file format"; + mxnet::NDArray::Load(fname, &data, &names); ret->ret_handles.resize(data.size()); for (size_t i = 0; i < data.size(); ++i) { NDArray *ptr = new NDArray(); @@ -521,6 +502,54 @@ int MXSymbolCreateGroup(mx_uint num_symbols, API_END_HANDLE_ERROR(delete s); } +int MXSymbolCreateFromFile(const char *fname, SymbolHandle *out) { + Symbol *s = new Symbol(); + API_BEGIN(); + std::unique_ptr fi(dmlc::Stream::Create(fname, "r")); + dmlc::istream is(fi.get()); + dmlc::JSONReader reader(&is); + s->Load(&reader); + // reset file pointer + is.set_stream(nullptr); + *out = s; + API_END_HANDLE_ERROR(delete s); +} + +int MXSymbolCreateFromJSON(const char *json, SymbolHandle *out) { + Symbol *s = new Symbol(); + API_BEGIN(); + std::string buf(json); + std::istringstream is(buf); + dmlc::JSONReader reader(&is); + s->Load(&reader); + *out = s; + API_END_HANDLE_ERROR(delete s); +} + +int MXSymbolSaveToFile(SymbolHandle symbol, const char *fname) { + Symbol *s = static_cast(symbol); + API_BEGIN(); + std::unique_ptr fo(dmlc::Stream::Create(fname, "w")); + dmlc::ostream os(fo.get()); + dmlc::JSONWriter writer(&os); + s->Save(&writer); + // reset file pointer, force flush + os.set_stream(nullptr); + API_END(); +} + +int MXSymbolSaveToJSON(SymbolHandle symbol, const char **out_json) { + Symbol *s = static_cast(symbol); + MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); + API_BEGIN(); + std::ostringstream os; + dmlc::JSONWriter writer(&os); + s->Save(&writer); + ret->ret_str = os.str(); + *out_json = ret->ret_str.c_str(); + API_END(); +} + int MXSymbolFree(SymbolHandle symbol) { API_BEGIN(); delete static_cast(symbol); diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc index 920fe1bed709..21edc44580f7 100644 --- a/src/ndarray/ndarray.cc +++ b/src/ndarray/ndarray.cc @@ -3,6 +3,7 @@ * \file ndarray.cc * \brief ndarry module of mxnet */ +#include #include #include #include @@ -403,6 +404,39 @@ bool NDArray::Load(dmlc::Stream *strm) { } } + +const uint64_t kMXAPINDArrayListMagic = 0x112; + +void NDArray::Save(const std::string& fname, + const std::vector& data, + const std::vector& names) { + std::unique_ptr fo(dmlc::Stream::Create(fname.c_str(), "w")); + uint64_t header = kMXAPINDArrayListMagic, reserved = 0; + fo->Write(&header, sizeof(header)); + fo->Write(&reserved, sizeof(reserved)); + fo->Write(data); + fo->Write(names); +} + +void NDArray::Load(const std::string& fname, + std::vector* data, + std::vector* keys) { + std::unique_ptr fi(dmlc::Stream::Create(fname.c_str(), "r")); + uint64_t header, reserved; + CHECK(fi->Read(&header)) + << "Invalid NDArray file format"; + CHECK(fi->Read(&reserved)) + << "Invalid NDArray file format"; + CHECK(header == kMXAPINDArrayListMagic) + << "Invalid NDArray file format"; + CHECK(fi->Read(data)) + << "Invalid NDArray file format"; + CHECK(fi->Read(keys)) + << "Invalid NDArray file format"; + CHECK(keys->size() == 0 || keys->size() == data->size()) + << "Invalid NDArray file format"; +} + NDArray NDArray::Copy(Context ctx) const { NDArray ret(shape(), ctx, true); CopyFromTo(*this, &ret); diff --git a/src/operator/activation-inl.h b/src/operator/activation-inl.h index 9105c37fd1b2..f344445930b2 100644 --- a/src/operator/activation-inl.h +++ b/src/operator/activation-inl.h @@ -84,12 +84,16 @@ template Operator* CreateOp(ActivationParam type); #if DMLC_USE_CXX11 -class ActivationProp : public ParamOperatorProperty { +class ActivationProp : public OperatorProperty { public: void Init(const std::vector >& kwargs) override { param_.Init(kwargs); } + std::map GetParams() const override { + return param_.__DICT__(); + } + bool InferShape(std::vector *in_shape, std::vector *out_shape, std::vector *aux_shape) const override { @@ -139,6 +143,9 @@ class ActivationProp : public ParamOperatorProperty { } Operator* CreateOperator(Context ctx) const; + + private: + ActivationParam param_; }; #endif // DMLC_USE_CXX11 } // namespace op diff --git a/src/operator/batch_norm-inl.h b/src/operator/batch_norm-inl.h index 3827dbb909e6..625614725282 100644 --- a/src/operator/batch_norm-inl.h +++ b/src/operator/batch_norm-inl.h @@ -182,12 +182,16 @@ Operator *CreateOp(BatchNormParam param); #if DMLC_USE_CXX11 -class BatchNormProp : public ParamOperatorProperty { +class BatchNormProp : public OperatorProperty { public: void Init(const std::vector >& kwargs) override { param_.Init(kwargs); } + std::map GetParams() const override { + return param_.__DICT__(); + } + bool InferShape(std::vector *in_shape, std::vector *out_shape, std::vector *aux_shape) const override { @@ -263,6 +267,8 @@ class BatchNormProp : public ParamOperatorProperty { Operator* CreateOperator(Context ctx) const; + private: + BatchNormParam param_; }; // class BatchNormProp #endif // DMLC_USE_CXX11 diff --git a/src/operator/concat-inl.h b/src/operator/concat-inl.h index 1a69a261d9bb..6929d35acd0b 100644 --- a/src/operator/concat-inl.h +++ b/src/operator/concat-inl.h @@ -163,12 +163,16 @@ template Operator *CreateOp(ConcatParam param); #if DMLC_USE_CXX11 -class ConcatProp : public ParamOperatorProperty { +class ConcatProp : public OperatorProperty { public: void Init(const std::vector >& kwargs) override { param_.Init(kwargs); } + std::map GetParams() const override { + return param_.__DICT__(); + } + std::vector ListArguments() const override { std::vector ret; for (int i = 0; i < param_.num_args; ++i) { @@ -223,6 +227,9 @@ class ConcatProp : public ParamOperatorProperty { } Operator* CreateOperator(Context ctx) const; + + private: + ConcatParam param_; }; // class ConcatProp #endif // DMLC_USE_CXX11 } // namespace op diff --git a/src/operator/convolution-inl.h b/src/operator/convolution-inl.h index ca7c4e609e73..4d637a23453a 100644 --- a/src/operator/convolution-inl.h +++ b/src/operator/convolution-inl.h @@ -57,6 +57,8 @@ class ConvolutionOp : public Operator { public: explicit ConvolutionOp(ConvolutionParam p) { this->param_ = p; + // convert MB to words + param_.workspace = (param_.workspace << 20) / sizeof(real_t); } virtual void Forward(const OpContext &ctx, @@ -262,7 +264,7 @@ template Operator* CreateOp(ConvolutionParam param); #if DMLC_USE_CXX11 -class ConvolutionProp : public ParamOperatorProperty { +class ConvolutionProp : public OperatorProperty { public: std::vector ListArguments() const override { if (!param_.no_bias) { @@ -274,8 +276,10 @@ class ConvolutionProp : public ParamOperatorProperty { void Init(const std::vector >& kwargs) override { param_.Init(kwargs); - // convert MB to words - param_.workspace = (param_.workspace << 20) / sizeof(real_t); + } + + std::map GetParams() const override { + return param_.__DICT__(); } bool InferShape(std::vector *in_shape, @@ -358,6 +362,8 @@ class ConvolutionProp : public ParamOperatorProperty { Operator* CreateOperator(Context ctx) const; + private: + ConvolutionParam param_; }; // class ConvolutionProp #endif // DMLC_USE_CXX11 } // namespace op diff --git a/src/operator/cudnn_convolution-inl.h b/src/operator/cudnn_convolution-inl.h index 57c90d241f8f..2a89e7ee72bc 100644 --- a/src/operator/cudnn_convolution-inl.h +++ b/src/operator/cudnn_convolution-inl.h @@ -18,6 +18,8 @@ class CuDNNConvolutionOp : public Operator { public: explicit CuDNNConvolutionOp(ConvolutionParam param) { this->param_ = param; + // convert MB to words + param_.workspace = (param_.workspace << 20) / sizeof(real_t); init_cudnn_ = false; // TODO(xxx): fp16 dtype_ = CUDNN_DATA_FLOAT; diff --git a/src/operator/dropout-inl.h b/src/operator/dropout-inl.h index 6afdba146e03..877eab61226b 100644 --- a/src/operator/dropout-inl.h +++ b/src/operator/dropout-inl.h @@ -91,12 +91,16 @@ template Operator *CreateOp(DropoutParam param); #if DMLC_USE_CXX11 -class DropoutProp : public ParamOperatorProperty { +class DropoutProp : public OperatorProperty { public: void Init(const std::vector >& kwargs) override { param_.Init(kwargs); } + std::map GetParams() const override { + return param_.__DICT__(); + } + bool InferShape(std::vector *in_shape, std::vector *out_shape, std::vector *aux_shape) const override { @@ -160,6 +164,8 @@ class DropoutProp : public ParamOperatorProperty { Operator* CreateOperator(Context ctx) const; + private: + DropoutParam param_; }; // class DropoutProp #endif // DMLC_USE_CXX11 } // namespace op diff --git a/src/operator/elementwise_binary_op-inl.h b/src/operator/elementwise_binary_op-inl.h index 9cd6ef57be11..73c234e96965 100644 --- a/src/operator/elementwise_binary_op-inl.h +++ b/src/operator/elementwise_binary_op-inl.h @@ -12,6 +12,7 @@ #include #include #include +#include #include "./operator_common.h" #include "./mshadow_op.h" @@ -157,12 +158,15 @@ Operator* CreateElementWiseBinaryOp(ElementWiseBinaryOpType type); #if DMLC_USE_CXX11 template -class ElementWiseBinaryOpProp : public NoParamOperatorProperty { +class ElementWiseBinaryOpProp : public OperatorProperty { public: void Init(const std::vector >& kwargs) override { CHECK_EQ(kwargs.size(), 0) << TypeString() << " do not take any additional keyword arguments besides lhs and rhs"; } + std::map GetParams() const override { + return {}; + } bool InferShape(std::vector *in_shape, std::vector *out_shape, diff --git a/src/operator/elementwise_sum-inl.h b/src/operator/elementwise_sum-inl.h index ebd31f155158..213add51357a 100644 --- a/src/operator/elementwise_sum-inl.h +++ b/src/operator/elementwise_sum-inl.h @@ -121,11 +121,14 @@ template Operator* CreateOp(ElementWiseSumParam param); #if DMLC_USE_CXX11 -class ElementWiseSumProp : public ParamOperatorProperty { +class ElementWiseSumProp : public OperatorProperty { public: void Init(const std::vector >& kwargs) override { param_.Init(kwargs); } + std::map GetParams() const override { + return param_.__DICT__(); + } bool InferShape(std::vector *in_shape, std::vector *out_shape, @@ -190,8 +193,10 @@ class ElementWiseSumProp : public ParamOperatorProperty { } Operator* CreateOperator(Context ctx) const; -}; // class ElementWiseSumProp + private: + ElementWiseSumParam param_; +}; // class ElementWiseSumProp #endif // DMLC_USE_CXX11 } // namespace op diff --git a/src/operator/fully_connected-inl.h b/src/operator/fully_connected-inl.h index dfc718596103..6fec9f5d13a5 100644 --- a/src/operator/fully_connected-inl.h +++ b/src/operator/fully_connected-inl.h @@ -124,7 +124,7 @@ template Operator* CreateOp(FullyConnectedParam param); #if DMLC_USE_CXX11 -class FullyConnectedProp : public ParamOperatorProperty { +class FullyConnectedProp : public OperatorProperty { public: std::vector ListArguments() const override { if (!param_.no_bias) { @@ -138,6 +138,10 @@ class FullyConnectedProp : public ParamOperatorProperty { param_.Init(kwargs); } + std::map GetParams() const override { + return param_.__DICT__(); + } + bool InferShape(std::vector *in_shape, std::vector *out_shape, std::vector *aux_shape) const override { @@ -170,8 +174,9 @@ class FullyConnectedProp : public ParamOperatorProperty { } std::string TypeString() const override { - return "FullyConnecteded"; + return "FullyConnected"; } + // decalre dependency and inplace optimization options std::vector DeclareBackwardDependency( const std::vector &out_grad, @@ -189,6 +194,9 @@ class FullyConnectedProp : public ParamOperatorProperty { } Operator* CreateOperator(Context ctx) const; + + private: + FullyConnectedParam param_; }; // class FullyConnectedSymbol #endif } // namespace op diff --git a/src/operator/leaky_relu-inl.h b/src/operator/leaky_relu-inl.h index 5f4ed83990ea..68cb52eea25f 100644 --- a/src/operator/leaky_relu-inl.h +++ b/src/operator/leaky_relu-inl.h @@ -190,12 +190,16 @@ template Operator* CreateOp(LeakyReLUParam type); #if DMLC_USE_CXX11 -class LeakyReLUProp : public ParamOperatorProperty { +class LeakyReLUProp : public OperatorProperty { public: void Init(const std::vector >& kwargs) override { param_.Init(kwargs); } + std::map GetParams() const override { + return param_.__DICT__(); + } + bool InferShape(std::vector *in_shape, std::vector *out_shape, std::vector *aux_shape) const override { @@ -298,6 +302,9 @@ class LeakyReLUProp : public ParamOperatorProperty { } Operator* CreateOperator(Context ctx) const; + + private: + LeakyReLUParam param_; }; #endif // DMLC_USE_CXX11 } // namespace op diff --git a/src/operator/lrn-inl.h b/src/operator/lrn-inl.h index 7b9326bd6892..93c0e346de42 100644 --- a/src/operator/lrn-inl.h +++ b/src/operator/lrn-inl.h @@ -98,12 +98,16 @@ template Operator *CreateOp(LRNParam param); #if DMLC_USE_CXX11 -class LocalResponseNormProp : public ParamOperatorProperty { +class LocalResponseNormProp : public OperatorProperty { public: void Init(const std::vector >& kwargs) override { param_.Init(kwargs); } + std::map GetParams() const override { + return param_.__DICT__(); + } + bool InferShape(std::vector *in_shape, std::vector *out_shape, std::vector *aux_shape) const override { @@ -173,6 +177,9 @@ class LocalResponseNormProp : public ParamOperatorProperty { } Operator* CreateOperator(Context ctx) const; + + private: + LRNParam param_; }; // LocalResponseNormProp #endif // DMLC_USE_CXX11 } // namespace op diff --git a/src/operator/operator_common.h b/src/operator/operator_common.h index 299dbde0ed47..64714f1e7633 100644 --- a/src/operator/operator_common.h +++ b/src/operator/operator_common.h @@ -95,44 +95,6 @@ struct InferShapeError { return nullptr; \ } #endif - -#if DMLC_USE_CXX11 -template -class ParamOperatorProperty : public OperatorProperty { - public: - ParamOperatorProperty() {} - explicit ParamOperatorProperty(Param param) : param_(param) {} - inline void Save(dmlc::JSONWriter *writer) const { - writer->BeginObject(); - std::string value = param_.PrintJson(); - writer->WriteObjectKeyValue("param", value); - writer->EndObject(); - } - inline void Load(dmlc::JSONReader *reader) { - dmlc::JSONObjectReadHelper helper; - std::string value; - helper.DeclareField("param", &value); - helper.ReadAllFields(reader); - param_.LoadJson(value); - } - inline bool operator==(const ParamOperatorProperty& other) const { - return param_ == other.param_; - } - protected: - Param param_; -}; - -class NoParamOperatorProperty : public OperatorProperty { - public: - inline void Save(dmlc::JSONWriter *writer) const { - } - inline void Load(dmlc::JSONReader *reader) { - } - inline bool operator==(const NoParamOperatorProperty& other) const { - return true; - } -}; -#endif // DMLC_USE_CXX11 } // namespace op } // namespace mxnet #endif // MXNET_OPERATOR_OPERATOR_COMMON_H_ diff --git a/src/operator/pooling-inl.h b/src/operator/pooling-inl.h index 50aee978d34b..1303558e5bae 100644 --- a/src/operator/pooling-inl.h +++ b/src/operator/pooling-inl.h @@ -154,12 +154,16 @@ Operator* CreateOp(PoolingParam param); #if DMLC_USE_CXX11 -class PoolingProp : public ParamOperatorProperty { +class PoolingProp : public OperatorProperty { public: void Init(const std::vector >& kwargs) override { param_.Init(kwargs); } + std::map GetParams() const override { + return param_.__DICT__(); + } + bool InferShape(std::vector *in_shape, std::vector *out_shape, std::vector *aux_shape) const override { @@ -209,6 +213,9 @@ class PoolingProp : public ParamOperatorProperty { } Operator* CreateOperator(Context ctx) const; + + private: + PoolingParam param_; }; // class PoolingProp #endif // DMLC_USE_CXX11 } // namespace op diff --git a/src/operator/reshape-inl.h b/src/operator/reshape-inl.h index 69a1e5e73143..a227ec93947f 100644 --- a/src/operator/reshape-inl.h +++ b/src/operator/reshape-inl.h @@ -83,18 +83,16 @@ template Operator* CreateOp(); #if DMLC_USE_CXX11 -class ReshapeProp : public ParamOperatorProperty { +class ReshapeProp : public OperatorProperty { public: ReshapeProp() {} - explicit ReshapeProp(ReshapeParam param) : ParamOperatorProperty(param) {} - void Init(const std::vector >& kwargs) override { param_.Init(kwargs); } - std::string TypeString() const override { - return "Reshape"; + std::map GetParams() const override { + return param_.__DICT__(); } bool InferShape(std::vector *in_shape, @@ -118,6 +116,10 @@ class ReshapeProp : public ParamOperatorProperty { return ptr; } + std::string TypeString() const override { + return "Reshape"; + } + std::vector DeclareBackwardDependency( const std::vector &out_grad, const std::vector &in_data, @@ -140,12 +142,19 @@ class ReshapeProp : public ParamOperatorProperty { } Operator* CreateOperator(Context ctx) const; + + protected: + ReshapeParam param_; }; // class ReshapeProp class FlattenProp : public ReshapeProp { public: void Init(const std::vector >& kwargs) override {} + std::map GetParams() const override { + return {}; + } + std::string TypeString() const override { return "Flatten"; } diff --git a/src/operator/softmax-inl.h b/src/operator/softmax-inl.h index b90b2b7343c8..87cab0cb4568 100644 --- a/src/operator/softmax-inl.h +++ b/src/operator/softmax-inl.h @@ -83,7 +83,7 @@ template Operator* CreateOp(SoftmaxParam param); #if DMLC_USE_CXX11 -class SoftmaxProp : public ParamOperatorProperty { +class SoftmaxProp : public OperatorProperty { public: std::vector ListArguments() const override { return {"data", "label"}; @@ -93,6 +93,10 @@ class SoftmaxProp : public ParamOperatorProperty { param_.Init(kwargs); } + std::map GetParams() const override { + return param_.__DICT__(); + } + bool InferShape(std::vector *in_shape, std::vector *out_shape, std::vector *aux_shape) const override { @@ -138,6 +142,9 @@ class SoftmaxProp : public ParamOperatorProperty { } Operator* CreateOperator(Context ctx) const; + + private: + SoftmaxParam param_; }; // class SoftmaxProp #endif // DMLC_USE_CXX11 diff --git a/src/symbol/static_graph.cc b/src/symbol/static_graph.cc index 4e58908b0fd2..9afcfec67097 100644 --- a/src/symbol/static_graph.cc +++ b/src/symbol/static_graph.cc @@ -292,34 +292,17 @@ void StaticGraph::MakeBackwardPass(std::vector *head_grad_nodes, } } -void StaticGraph::DataEntry::Save(dmlc::JSONWriter *writer) const { - writer->BeginObject(); - writer->WriteObjectKeyValue("source_id", source_id); - writer->WriteObjectKeyValue("index", index); - writer->EndObject(); -} - -void StaticGraph::DataEntry::Load(dmlc::JSONReader *reader) { - dmlc::JSONObjectReadHelper helper; - helper.DeclareField("source_id", &source_id); - helper.DeclareField("index", &index); - helper.ReadAllFields(reader); -} - void StaticGraph::Node::Save(dmlc::JSONWriter *writer) const { writer->BeginObject(); if (op.get() != nullptr) { - writer->WriteObjectKeyValue("op_type", op.get()->TypeString()); - std::ostringstream os; - dmlc::JSONWriter subWriter(&os); - subWriter.BeginObject(); - subWriter.WriteObjectKeyValue("op", *(op.get())); - subWriter.EndObject(); - writer->WriteObjectKeyValue("op", os.str()); + writer->WriteObjectKeyValue("op", op->TypeString()); + std::map param = op->GetParams(); + writer->WriteObjectKeyValue("param", param); } else { - std::string jsonNull = "null"; - writer->WriteObjectKeyValue("op_type", jsonNull); - writer->WriteObjectKeyValue("op", jsonNull); + std::map empty_param; + std::string json_null = "null"; + writer->WriteObjectKeyValue("op", json_null); + writer->WriteObjectKeyValue("param", empty_param); } writer->WriteObjectKeyValue("name", name); writer->WriteObjectKeyValue("inputs", inputs); @@ -328,22 +311,20 @@ void StaticGraph::Node::Save(dmlc::JSONWriter *writer) const { } void StaticGraph::Node::Load(dmlc::JSONReader *reader) { - dmlc::JSONObjectReadHelper firstHelper; + dmlc::JSONObjectReadHelper helper; std::string op_type_str; - firstHelper.DeclareField("op_type", &op_type_str); - std::string op_str; - firstHelper.DeclareField("op", &op_str); - firstHelper.DeclareField("name", &name); - firstHelper.DeclareField("inputs", &inputs); - firstHelper.DeclareField("backward_source_id", &backward_source_id); - firstHelper.ReadAllFields(reader); + std::map param; + helper.DeclareField("op", &op_type_str); + helper.DeclareField("param", ¶m); + helper.DeclareField("name", &name); + helper.DeclareField("inputs", &inputs); + helper.DeclareField("backward_source_id", &backward_source_id); + helper.ReadAllFields(reader); + if (op_type_str != "null") { - dmlc::JSONObjectReadHelper secondHelper; - std::istringstream iss(op_str); - dmlc::JSONReader subReader(&iss); op.reset(OperatorProperty::Create(op_type_str.c_str())); - secondHelper.DeclareField("op", op.get()); - secondHelper.ReadAllFields(reader); + std::vector > vec(param.begin(), param.end()); + op->Init(vec); } else { op.reset(nullptr); } @@ -364,18 +345,4 @@ void StaticGraph::Load(dmlc::JSONReader *reader) { helper.DeclareField("heads", &heads); helper.ReadAllFields(reader); } - -void StaticGraph::Load(const std::string& json) { - std::istringstream is(json); - dmlc::JSONReader reader(&is); - reader.Read(this); -} - -void StaticGraph::Save(std::string* json) const { - std::ostringstream os; - dmlc::JSONWriter writer(&os); - writer.Write(*this); - *json = os.str(); -} - } // namespace mxnet diff --git a/src/symbol/symbol.cc b/src/symbol/symbol.cc index 9a0b1e0e997d..972802a6777d 100644 --- a/src/symbol/symbol.cc +++ b/src/symbol/symbol.cc @@ -555,4 +555,32 @@ void Symbol::ToStaticGraph(StaticGraph *out_graph) const { out_graph->heads.push_back(e); } } + +void Symbol::FromStaticGraph(const StaticGraph &graph) { + std::unordered_map > nodes; + std::vector topo_order = graph.TopoSort(); + // copy ver nodes in topo order + for (uint32_t nid : topo_order) { + auto &gnode = graph.nodes[nid]; + auto sym_node = std::make_shared(); + sym_node->name = gnode.name; + if (gnode.op.get() != nullptr) { + sym_node->op.reset(gnode.op->Copy()); + } + if (gnode.backward_source_id != -1) { + sym_node->backward_source_node = nodes.at(gnode.backward_source_id); + } + for (const StaticGraph::DataEntry& e : gnode.inputs) { + Symbol::DataEntry entry(nodes.at(e.source_id), e.index); + sym_node->inputs.push_back(std::move(entry)); + } + nodes[nid] = sym_node; + } + // generate the heads + heads_.clear(); + for (const StaticGraph::DataEntry& e : graph.heads) { + Symbol::DataEntry entry(nodes.at(e.source_id), e.index); + heads_.push_back(std::move(entry)); + } +} } // namespace mxnet diff --git a/tests/python/common/models.py b/tests/python/common/models.py index d7fb74e4fd1e..71df3f07cf47 100644 --- a/tests/python/common/models.py +++ b/tests/python/common/models.py @@ -8,3 +8,22 @@ def mlp2(): out = mx.symbol.FullyConnected(data=out, name='fc2', num_hidden=10) return out + + +def conv(): + data = mx.symbol.Variable('data') + conv1= mx.symbol.Convolution(data = data, name='conv1', num_filter=32, kernel=(3,3), stride=(2,2)) + bn1 = mx.symbol.BatchNorm(data = conv1, name="bn1") + act1 = mx.symbol.Activation(data = bn1, name='relu1', act_type="relu") + mp1 = mx.symbol.Pooling(data = act1, name = 'mp1', kernel=(2,2), stride=(2,2), pool_type='max') + + conv2= mx.symbol.Convolution(data = mp1, name='conv2', num_filter=32, kernel=(3,3), stride=(2,2)) + bn2 = mx.symbol.BatchNorm(data = conv2, name="bn2") + act2 = mx.symbol.Activation(data = bn2, name='relu2', act_type="relu") + mp2 = mx.symbol.Pooling(data = act2, name = 'mp2', kernel=(2,2), stride=(2,2), pool_type='max') + + fl = mx.symbol.Flatten(data = mp2, name="flatten") + fc2 = mx.symbol.FullyConnected(data = fl, name='fc2', num_hidden=10) + softmax = mx.symbol.Softmax(data = fc2, name = 'sm') + return softmax + diff --git a/tests/python/unittest/test_symbol.py b/tests/python/unittest/test_symbol.py index 9e29f0b9ffb5..4f7f7eb1109f 100644 --- a/tests/python/unittest/test_symbol.py +++ b/tests/python/unittest/test_symbol.py @@ -1,5 +1,7 @@ +import os import mxnet as mx from common import models +import pickle as pkl def test_symbol_basic(): mlist = [] @@ -9,7 +11,7 @@ def test_symbol_basic(): m.list_outputs() -def test_compose(): +def test_symbol_compose(): data = mx.symbol.Variable('data') net1 = mx.symbol.FullyConnected(data=data, name='fc1', num_hidden=10) net1 = mx.symbol.FullyConnected(data=net1, name='fc2', num_hidden=100) @@ -26,3 +28,30 @@ def test_compose(): print(composed.debug_str()) multi_out = mx.symbol.Group([composed, net1]) assert len(multi_out.list_outputs()) == 2 + + +def test_symbol_pickle(): + mlist = [models.mlp2(), models.conv()] + data = pkl.dumps(mlist) + mlist2 = pkl.loads(data) + for x, y in zip(mlist, mlist2): + assert x.tojson() == y.tojson() + + +def test_symbol_saveload(): + sym = models.mlp2() + fname = 'tmp_sym.json' + sym.save(fname) + print sym.tojson() + data2 = mx.symbol.load(fname) + # save because of order + assert sym.tojson() == data2.tojson() + os.remove(fname) + + +if __name__ == '__main__': + test_symbol_basic() + test_symbol_compose() + test_symbol_saveload() + test_symbol_pickle() +