Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Finish Up JSON serialization of symbol.
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen committed Sep 18, 2015
1 parent b74d1f2 commit 3e61d8f
Show file tree
Hide file tree
Showing 28 changed files with 569 additions and 193 deletions.
2 changes: 1 addition & 1 deletion dmlc-core
48 changes: 34 additions & 14 deletions include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -123,10 +123,10 @@ MXNET_DLL int MXNDArraySaveRawBytes(NDArrayHandle handle,
* \param keys the name of the NDArray, optional, can be NULL
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXNDArrayListSave(const char* fname,
mx_uint num_args,
NDArrayHandle* args,
const char** keys);
MXNET_DLL int MXNDArraySave(const char* fname,
mx_uint num_args,
NDArrayHandle* args,
const char** keys);
/*!
* \brief Load list of narray from the file.
* \param fname name of the file.
Expand All @@ -136,11 +136,11 @@ MXNET_DLL int MXNDArrayListSave(const char* fname,
* \param out_names the names of returning NDArrays, can be NULL
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXNDArrayListLoad(const char* fname,
mx_uint *out_size,
NDArrayHandle** out_arr,
mx_uint *out_name_size,
const char*** out_names);
MXNET_DLL int MXNDArrayLoad(const char* fname,
mx_uint *out_size,
NDArrayHandle** out_arr,
mx_uint *out_name_size,
const char*** out_names);
/*!
* \brief Perform a synchronize copy from a continugous CPU memory region.
*
Expand Down Expand Up @@ -359,13 +359,33 @@ MXNET_DLL int MXSymbolCreateGroup(mx_uint num_symbols,
SymbolHandle *symbols,
SymbolHandle *out);
/*!
* \brief Create symbol from config.
* \param cfg configuration string
* \param out created symbol handle
* \brief Load a symbol from a json file.
* \param fname the file name.
* \param out the output symbol.
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXSymbolCreateFromFile(const char *fname, SymbolHandle *out);
/*!
* \brief Load a symbol from a json string.
* \param json the json string.
* \param out the output symbol.
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXSymbolCreateFromJSON(const char *json, SymbolHandle *out);
/*!
* \brief Save a symbol into a json file.
* \param sym the input symbol.
* \param fname the file name.
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXSymbolSaveToFile(SymbolHandle symbol, const char *fname);
/*!
* \brief Save a symbol into a json string
* \param sym the input symbol.
* \param out_json output json string.
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXSymbolCreateFromConfig(const char *cfg,
SymbolHandle *out);
MXNET_DLL int MXSymbolSaveToJSON(SymbolHandle symbol, const char **out_json);
/*!
* \brief Free the symbol handle.
* \param symbol the symbol
Expand Down
21 changes: 20 additions & 1 deletion include/mxnet/ndarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
#include <dmlc/io.h>
#include <dmlc/type_traits.h>
#include <dmlc/registry.h>
#include <vector>
#include <string>
#include <memory>
#include "./base.h"
#include "./storage.h"
Expand Down Expand Up @@ -244,6 +246,24 @@ class NDArray {
inline void CheckAndAlloc() const {
ptr_->CheckAndAlloc();
}
/*!
* \brief Save list of narray into the file.
* \param fname name of the file.
* \param data the NDArrays to be saved.
* \param keys the name of the NDArray, optional, can be zero length.
*/
static void Save(const std::string& fname,
const std::vector<NDArray>& data,
const std::vector<std::string>& names);
/*!
* \brief Load list of narray into from the file.
* \param fname name of the file.
* \param data the NDArrays to be loaded
* \param keys the name of the NDArray, if saved in the file.
*/
static void Load(const std::string& fname,
std::vector<NDArray>* data,
std::vector<std::string>* keys);

private:
/*! \brief the real data chunk that backs NDArray */
Expand Down Expand Up @@ -397,7 +417,6 @@ void SampleUniform(real_t begin, real_t end, NDArray *out);
* \param out output NDArray.
*/
void SampleGaussian(real_t mu, real_t sigma, NDArray *out);

//--------------------------------------------------------------
// The following part are API Registration of NDArray functions.
//--------------------------------------------------------------
Expand Down
29 changes: 24 additions & 5 deletions include/mxnet/operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include <dmlc/logging.h>
#include <dmlc/registry.h>
#include <vector>
#include <map>
#include <string>
#include <utility>
#include "./base.h"
Expand Down Expand Up @@ -150,6 +151,11 @@ class OperatorProperty {
* \param kwargs the keyword arguments parameters
*/
virtual void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) = 0;
/*!
* \brief Get a map representation of internal parameters.
* This can be used by Init to recover the state of OperatorProperty.
*/
virtual std::map<std::string, std::string> GetParams() const = 0;
/*!
* \brief Get input arguments of the Operator.
* \return vector of arguments.
Expand Down Expand Up @@ -222,6 +228,7 @@ class OperatorProperty {
/*!
* \brief return the type string of the Operator
* subclasses override this function.
* \return The type string.
*/
virtual std::string TypeString() const = 0;
//--------------------------------------------------------
Expand Down Expand Up @@ -390,9 +397,6 @@ class OperatorProperty {
* \return a new constructed OperatorProperty
*/
static OperatorProperty *Create(const char* type_name);

virtual void Save(dmlc::JSONWriter *writer) const = 0;
virtual void Load(dmlc::JSONReader *reader) = 0;
};

/*! \brief typedef the factory function of operator property */
Expand All @@ -419,6 +423,19 @@ struct OperatorPropertyReg
this->key_var_num_args = key;
return *this;
}
/*!
* \brief Check if TypeString of the type matches the registered name
*/
inline OperatorPropertyReg& check_name() {
OperatorProperty *p = this->body();
std::string type = p->TypeString();
delete p;
CHECK_EQ(this->name, type)
<< "Register Name and TypeString mismatch, name=\"" << this->name << "\","
<< " but TypeString=\"" << type <<"\"";
return *this;
}

/*! \brief The key num_args name. */
std::string key_var_num_args;
};
Expand All @@ -438,10 +455,12 @@ struct OperatorPropertyReg
*/
#define MXNET_REGISTER_OP_PROPERTY(name, OperatorPropertyType) \
static ::mxnet::OperatorProperty* __create__ ## OperatorProperty ## name ## __() { \
return new OperatorPropertyType; \
OperatorProperty* ret = new OperatorPropertyType(); \
return ret; \
} \
DMLC_REGISTRY_REGISTER(::mxnet::OperatorPropertyReg, OperatorPropertyReg, name) \
.set_body(__create__ ## OperatorProperty ## name ## __)
.set_body(__create__ ## OperatorProperty ## name ## __) \
.check_name()

#endif // DMLC_USE_CXX11
} // namespace mxnet
Expand Down
84 changes: 64 additions & 20 deletions include/mxnet/symbolic.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include <algorithm>
#include <vector>
#include <memory>
#include <map>
#include <string>
#include <utility>
#include <functional>
Expand Down Expand Up @@ -66,11 +67,25 @@ class StaticGraph {
if (source_id == other.source_id) return index < other.index;
return source_id < other.source_id;
}

/*! \brief interface for json serialization */
void Save(dmlc::JSONWriter *writer) const;
/*! \brief interface for json serialization */
void Load(dmlc::JSONReader *reader);
/*!
* \brief interface for json serialization.
* \param writer the JSON writer to write json into.
*/
inline void Save(dmlc::JSONWriter *writer) const {
writer->BeginArray(false);
writer->WriteArrayItem(source_id);
writer->WriteArrayItem(index);
writer->EndArray();
}
/*!
* \brief interface for json serialization.
* \param reader the JSON reader to read json from.
*/
inline void Load(dmlc::JSONReader *reader) {
std::pair<uint32_t, uint32_t> p;
reader->Read(&p);
*this = DataEntry(p.first, p.second);
}
};
/*!
* \brief Operation Node in static graphs.
Expand Down Expand Up @@ -131,9 +146,15 @@ class StaticGraph {
inline bool is_variable() const {
return op == nullptr && !is_backward();
}
/*! \brief interface for json serialization */
/*!
* \brief interface for json serialization.
* \param writer the JSON writer write json.
*/
void Save(dmlc::JSONWriter *writer) const;
/*! \brief interface for json serialization */
/*!
* \brief interface for json serialization.
* \param reader the JSON read to read json.
*/
void Load(dmlc::JSONReader *reader);
};
/*! \brief all nodes in the graph */
Expand All @@ -142,13 +163,15 @@ class StaticGraph {
std::vector<uint32_t> arg_nodes;
/*! \brief heads outputs of the graph */
std::vector<DataEntry> heads;
/*! \brief load static graph from json. TODO: a static creator's better */
void Load(const std::string& json);
/*! \brief save static graph to json */
void Save(std::string* json) const;
/*! \brief interface for json serialization */
/*!
* \brief interface for json serialization.
* \param writer the JSON writer write json.
*/
void Save(dmlc::JSONWriter *writer) const;
/*! \brief interface for json serialization */
/*!
* \brief interface for json serialization.
* \param reader the JSON read to read json.
*/
void Load(dmlc::JSONReader *reader);
// funtions to help inference in static graph
/*!
Expand Down Expand Up @@ -282,6 +305,12 @@ class Symbol {
* \param out_graph the pointer holder of the output graph
*/
void ToStaticGraph(StaticGraph *out_graph) const;
/*!
* \brief create equivalence of symbol from static graphs.
* This operation will change the content of current symbol.
* \param graph the static graph
*/
void FromStaticGraph(const StaticGraph &graph);
/*!
* \brief Apply the symbol as a function, compose with arguments
* \param args positional arguments for the symbol
Expand All @@ -303,7 +332,6 @@ class Symbol {
* \return the new symbol with gradient graph
*/
Symbol Grad(const std::vector<std::string>& wrt) const;

/*!
* \brief infer the shapes of outputs and unknown input arguments
* \param arg_shapes the shape of input arguments of the operator
Expand Down Expand Up @@ -335,6 +363,24 @@ class Symbol {
std::vector<TShape> *arg_shapes,
std::vector<TShape> *out_shapes,
std::vector<TShape> *aux_shapes) const;
/*!
* \brief interface for json serialization.
* \param writer the JSON writer write json.
*/
inline void Save(dmlc::JSONWriter *writer) const {
StaticGraph g;
this->ToStaticGraph(&g);
g.Save(writer);
}
/*!
* \brief interface for json serialization.
* \param reader the JSON read to read json.
*/
inline void Load(dmlc::JSONReader *reader) {
StaticGraph g;
g.Load(reader);
this->FromStaticGraph(g);
}
/*!
* \brief get number of outputs of this symbol
* \return number of outputs
Expand All @@ -351,12 +397,6 @@ class Symbol {
* \sa OperatorProperty::Create
*/
static Symbol Create(OperatorProperty *op);
/*!
* \brief create equivalence of symbol from static graphs
* \param graph the static graph
* \return the created symbol
*/
static Symbol Create(const StaticGraph &graph);
/*!
* \brief create equivalence of symbol by grouping the symbols together
* \param symbols list of symbols
Expand Down Expand Up @@ -466,4 +506,8 @@ class Executor {
const std::vector<NDArray> &aux_states);
}; // class operator
} // namespace mxnet

namespace dmlc {
DMLC_DECLARE_TRAITS(is_pod, ::mxnet::StaticGraph::DataEntry, true);
}
#endif // MXNET_SYMBOLIC_H_
Loading

0 comments on commit 3e61d8f

Please sign in to comment.