From 38c406ae27af8c92e1dfff5f1ca34df36b7926dc Mon Sep 17 00:00:00 2001 From: tqchen Date: Mon, 21 Sep 2015 08:54:04 -0700 Subject: [PATCH 1/2] [IO] remove default for data_shape, make it required --- example/cifar10/cifar10.py | 4 ++-- example/cifar10/cifar10_multi_gpus.py | 6 +++--- src/io/image_augmenter.h | 2 -- src/io/iter_batchloader.h | 2 -- src/io/iter_image_recordio.cc | 2 -- 5 files changed, 5 insertions(+), 11 deletions(-) diff --git a/example/cifar10/cifar10.py b/example/cifar10/cifar10.py index 128484c9d9ca..b99c49ea7423 100644 --- a/example/cifar10/cifar10.py +++ b/example/cifar10/cifar10.py @@ -165,7 +165,7 @@ def RandomInit(narray): mean_img="data/cifar/cifar_mean.bin", rand_crop=True, rand_mirror=True, - input_shape=(3,28,28), + data_shape=(3,28,28), batch_size=batch_size, preprocess_threads=1) test_dataiter = mx.io.ImageRecordIter( @@ -173,7 +173,7 @@ def RandomInit(narray): mean_img="data/cifar/cifar_mean.bin", rand_crop=False, rand_mirror=False, - input_shape=(3,28,28), + data_shape=(3,28,28), batch_size=batch_size, preprocess_threads=1) diff --git a/example/cifar10/cifar10_multi_gpus.py b/example/cifar10/cifar10_multi_gpus.py index dfe4a94cd026..e68e6edfc77b 100644 --- a/example/cifar10/cifar10_multi_gpus.py +++ b/example/cifar10/cifar10_multi_gpus.py @@ -1,4 +1,4 @@ -# pylint: skip-file +# Pylint: skip-file import numpy as np import mxnet as mx import copy @@ -153,7 +153,7 @@ def momentum_update(key, grad, weight): rand_crop=True, rand_mirror=True, shuffle=False, - input_shape=(3,28,28), + data_shape=(3,28,28), batch_size=batch_size, preprocess_threads=4, prefetch_buffer=6) @@ -163,7 +163,7 @@ def momentum_update(key, grad, weight): rand_crop=False, rand_mirror=False, shuffle=False, - input_shape=(3,28,28), + data_shape=(3,28,28), batch_size=batch_size, preprocess_threads=4, prefetch_buffer=6) diff --git a/src/io/image_augmenter.h b/src/io/image_augmenter.h index f3b4154425da..ff6906ebfbb4 100644 --- a/src/io/image_augmenter.h +++ 
b/src/io/image_augmenter.h @@ -120,9 +120,7 @@ struct ImageAugmentParam : public dmlc::Parameter { .describe("Augmentation Param: Maximum value of illumination variation."); DMLC_DECLARE_FIELD(silent).set_default(true) .describe("Augmentation Param: Whether to print augmentor info."); - index_t data_shape_default[] = {3, 224, 224}; DMLC_DECLARE_FIELD(data_shape) - .set_default(TShape(data_shape_default, data_shape_default + 3)) .set_expect_ndim(3).enforce_nonzero() .describe("Dataset Param: Input shape of the neural net."); } diff --git a/src/io/iter_batchloader.h b/src/io/iter_batchloader.h index 6752db02d658..2a082c57f4ff 100644 --- a/src/io/iter_batchloader.h +++ b/src/io/iter_batchloader.h @@ -33,9 +33,7 @@ struct BatchParam : public dmlc::Parameter { DMLC_DECLARE_PARAMETER(BatchParam) { DMLC_DECLARE_FIELD(batch_size) .describe("Batch Param: Batch size."); - index_t data_shape_default[] = {3, 224, 224}; DMLC_DECLARE_FIELD(data_shape) - .set_default(TShape(data_shape_default, data_shape_default + 3)) .set_expect_ndim(3).enforce_nonzero() .describe("Dataset Param: Shape of each instance generated by the DataIter."); DMLC_DECLARE_FIELD(label_width).set_default(1) diff --git a/src/io/iter_image_recordio.cc b/src/io/iter_image_recordio.cc index f8aded4defa7..7293defa04c6 100644 --- a/src/io/iter_image_recordio.cc +++ b/src/io/iter_image_recordio.cc @@ -109,9 +109,7 @@ struct ImageRecParserParam : public dmlc::Parameter { .describe("Dataset Param: Path to image record file."); DMLC_DECLARE_FIELD(label_width).set_lower_bound(1).set_default(1) .describe("Dataset Param: How many labels for an image."); - index_t data_shape_default[] = {3, 224, 224}; DMLC_DECLARE_FIELD(data_shape) - .set_default(TShape(data_shape_default, data_shape_default + 3)) .enforce_nonzero() .describe("Dataset Param: Shape of each instance generated by the DataIter."); DMLC_DECLARE_FIELD(preprocess_threads).set_lower_bound(1).set_default(4) From 7f52e0dfc4bbbc8969187f62cc22d8c82f39800a Mon Sep 17 
00:00:00 2001 From: tqchen Date: Mon, 21 Sep 2015 09:04:14 -0700 Subject: [PATCH 2/2] [SYMBOL] move static graph to internal --- include/mxnet/symbolic.h | 261 ++-------------------------- src/symbol/graph_algorithm.h | 1 + src/symbol/graph_executor.cc | 4 +- src/symbol/graph_executor.h | 3 +- src/symbol/graph_memory_allocator.h | 1 + src/symbol/static_graph.cc | 1 + src/symbol/static_graph.h | 242 ++++++++++++++++++++++++++ src/symbol/symbol.cc | 14 ++ 8 files changed, 282 insertions(+), 245 deletions(-) create mode 100644 src/symbol/static_graph.h diff --git a/include/mxnet/symbolic.h b/include/mxnet/symbolic.h index 3a8d5c0b2ca9..edf3e30cab2b 100644 --- a/include/mxnet/symbolic.h +++ b/include/mxnet/symbolic.h @@ -9,15 +9,11 @@ #include #include -#include #include #include #include #include #include -#include -#include -#include #include "./base.h" #include "./ndarray.h" #include "./operator.h" @@ -29,223 +25,16 @@ namespace mxnet { /*! - * \brief StaticGraph is the configuration of computation graphs. - * This is the "configuration file" of mxnet. - * It can be converted to/from Symbol, and can be used to bind to operators. + * \brief Internal data structure used for + * graph serialization and graph algorithms. */ -class StaticGraph { - public: - /*! \brief represents a data in the graph */ - struct DataEntry { - /*! \brief the source node id in the computation graph */ - uint32_t source_id; - /*! \brief index of output from the source. */ - uint32_t index; - /*! \brief default constructor */ - DataEntry() {} - /*! - * \brief constructor with source and index - * \param source_id source id - * \param index node index - */ - DataEntry(uint32_t source_id, uint32_t index) - : source_id(source_id), index(index) {} - /*! 
- * \brief compare equality - * \param other the other entry to compare - * \return whether two entries equals to each other - */ - inline bool operator==(const DataEntry &other) const { - return source_id == other.source_id && index == other.index; - } - /*! - * \brief comparator, allows to use map - * \param other the other entry to compare - * \return whether two entries is smaller than the other - */ - inline bool operator<(const DataEntry &other) const { - if (source_id == other.source_id) return index < other.index; - return source_id < other.source_id; - } - /*! - * \brief interface for json serialization. - * \param writer the JSON writer to write json into. - */ - inline void Save(dmlc::JSONWriter *writer) const { - writer->BeginArray(false); - writer->WriteArrayItem(source_id); - writer->WriteArrayItem(index); - writer->EndArray(); - } - /*! - * \brief interface for json serialization. - * \param reader the JSON reader to read json from. - */ - inline void Load(dmlc::JSONReader *reader) { - std::pair p; - reader->Read(&p); - *this = DataEntry(p.first, p.second); - } - }; - /*! - * \brief Operation Node in static graphs. - * There are two types of node, Forward and Backward Node. - * - * - Forward node corresponds to the op.Forward - * - Backward node corresponds to the Backward pass, - * where the corresponding forward node is indicated by backward_source_id. - * The op field in Backward node is nullptr - * - * The reason we explicit support Backward node is to allow special treatment - * such as shape inference and state sharing with Forward pass. - */ - struct Node { - /*! \brief wrapped operator property */ - std::unique_ptr op; - /*! \brief name of the node */ - std::string name; - /*! \brief inputs (node_id, index) for of the nodes*/ - std::vector inputs; - /*! - * \brief If this field is nonnegative, this indicates this - * Node is corresponds to a Backward Operation of Operator. - * backward_source_id will points to the corresponding Forward Node. 
- * - * For normal node, this field is -1. - * When the node is a Backward node, the op field will be nullptr - */ - int32_t backward_source_id; - /*! \brief default constructor */ - Node() : backward_source_id(-1) {} - - friend void swap(Node& lhs, Node& rhs) { - std::swap(lhs.op, rhs.op); - std::swap(lhs.name, rhs.name); - std::swap(lhs.inputs, rhs.inputs); - std::swap(lhs.backward_source_id, rhs.backward_source_id); - } - /*! \brief copy constructor in favor of serialization. */ - Node(const Node& another) : op(another.op.get() ? another.op.get()->Copy() : nullptr), - name(another.name), - inputs(another.inputs), - backward_source_id(another.backward_source_id) {} - - inline Node& operator=(Node another) { - swap(*this, another); - return *this; - } - /*! \return whether the node is forward op node */ - inline bool is_forward() const { - return op != nullptr; - } - /*! \return whether the node is backward op node */ - inline bool is_backward() const { - return backward_source_id != -1; - } - /*! \return whether the node is variable node */ - inline bool is_variable() const { - return op == nullptr && !is_backward(); - } - /*! - * \brief interface for json serialization. - * \param writer the JSON writer write json. - */ - void Save(dmlc::JSONWriter *writer) const; - /*! - * \brief interface for json serialization. - * \param reader the JSON read to read json. - */ - void Load(dmlc::JSONReader *reader); - }; - /*! \brief all nodes in the graph */ - std::vector nodes; - /*! \brief index of nodes that correspods to arguments */ - std::vector arg_nodes; - /*! \brief heads outputs of the graph */ - std::vector heads; - /*! - * \brief interface for json serialization. - * \param writer the JSON writer write json. - */ - void Save(dmlc::JSONWriter *writer) const; - /*! - * \brief interface for json serialization. - * \param reader the JSON read to read json. - */ - void Load(dmlc::JSONReader *reader); - // funtions to help inference in static graph - /*! 
- * \brief Perform a topological sort on the graph - * \return a topological order of node indices. - */ - std::vector TopoSort() const; - /*! - * \brief infer the node shapes in the computation graph. - * - * When calling this function, user can setup the shape information known into right position. - * Unknown shape are indicated by shape.ndim() == 0. - * - * \param topo_order The topological order of node index, as created by TopoSort. - * \param node_out_shapes The shapes of the each outputs of nodes in the graph. - * \param node_aux_shapes The shapes of the each auxiliary states of nodes in the graph. - * \return if the shape inference is successful, return true, else return false. - */ - bool InferNodeShapes(const std::vector &topo_order, - std::vector > *node_out_shapes, - std::vector > *node_aux_shapes) const; - /*! - * \brief infer the shapes of outputs and unknown input arguments - * \param in_shape the shape of input arguments of the operator - * this should be of same length as the vector returned by ListArguments - * in_shape allows unknown elements, which are checked by shape.ndim() == 0. - * For unknown shapes, InferShape will try to fill in the correct Shape in in_shape - * For known shapes, InferShape will check shape consistency - * - * common practice: set the shape of data input, and usually weight's shape can be infered - * - * \param out_shape the shape of outputs of the operator - * InferShape will modify the vector to fill output TShape - * \param aux_shape the shape of auxiliary states of the operator - * InferShape will modify the vector to fill output TShape - * \return if the shape inference is successful, return true, else return false. - */ - bool InferShape(std::vector* in_shape, - std::vector* out_shape, - std::vector* aux_shape) const; - /*! - * \brief Add a full backward pass in the static graph. 
- * This function will add gradient nodes for each heads, - * and add the backward pass to backprop the gradients all - * the way to the arguments. - * - * This will change the nodes field in the StaticGraph, but will not change other fields. - * The head and input of Backward pass will be returned by head_grad_nodes and arg_grads. - * - * \param head_grad_nodes used to store the created head gradient inputs for backward pass. - * \param arg_grads used to store gradients to args, can be multiple one if an argument is used by operator - */ - void MakeBackwardPass(std::vector *head_grad_nodes, - std::vector *arg_grads); - - /*! - * \brief create a sum node that aggregates gradient together - * \param grad_source the source of the inputs. - * \return a created ElementWiseSum node - */ - static Node CreateSumNode(const std::vector &grad_source); -}; - +class StaticGraph; /*! * \brief Symbol is used to represent dynamically generated symbolic computation graph. * * This class is used as a tool to generate computation graphs(aka. configuration) of the network. * Symbol is always composite, the head Node is the output node of the symbol. * An atomic symbol can be seen as a special case of the composite symbol with only the head node. - * - * The symbol can be converted from/to StaticGraph, the actual configuration used by mxnet. - * Symbol offers more flexible way to composite nodes than StaticGraph, which makes it good - * tool to generate configurations from language bindings such as python. - * \sa StaticGraph */ class Symbol { public: @@ -297,20 +86,6 @@ class Symbol { */ void Compose(const std::unordered_map& kwargs, const std::string& name); - /*! - * \brief Convert a list of symbols into static graph - * - * The user can go further to call bind function on static graph - * - * \param out_graph the pointer holder of the output graph - */ - void ToStaticGraph(StaticGraph *out_graph) const; - /*! - * \brief create equivalence of symbol from static graphs. 
- * This operation will change the content of current symbol. - * \param graph the static graph - */ - void FromStaticGraph(const StaticGraph &graph); /*! * \brief Apply the symbol as a function, compose with arguments * \param args positional arguments for the symbol @@ -367,20 +142,12 @@ class Symbol { * \brief interface for json serialization. * \param writer the JSON writer write json. */ - inline void Save(dmlc::JSONWriter *writer) const { - StaticGraph g; - this->ToStaticGraph(&g); - g.Save(writer); - } + void Save(dmlc::JSONWriter *writer) const; /*! * \brief interface for json serialization. * \param reader the JSON read to read json. */ - inline void Load(dmlc::JSONReader *reader) { - StaticGraph g; - g.Load(reader); - this->FromStaticGraph(g); - } + void Load(dmlc::JSONReader *reader); /*! * \brief get number of outputs of this symbol * \return number of outputs @@ -451,6 +218,20 @@ class Symbol { * \return maximum number of duplication factor */ int FindDuplicateArgs(std::unordered_map *out) const; + /*! + * \brief Convert symbol into internal static graph + * + * \param out_graph the pointer holder of the output graph + */ + void ToStaticGraph(StaticGraph *out_graph) const; + /*! + * \brief create equivalence of symbol from static graphs. + * This operation will change the content of current symbol. + * \param graph the static graph + */ + void FromStaticGraph(const StaticGraph &graph); + /*! \brief let static graph know the contents */ + friend class StaticGraph; }; /*! 
@@ -506,8 +287,4 @@ class Executor { const std::vector &aux_states); }; // class operator } // namespace mxnet - -namespace dmlc { -DMLC_DECLARE_TRAITS(is_pod, ::mxnet::StaticGraph::DataEntry, true); -} #endif // MXNET_SYMBOLIC_H_ diff --git a/src/symbol/graph_algorithm.h b/src/symbol/graph_algorithm.h index 021bc0744981..c009e28aca92 100644 --- a/src/symbol/graph_algorithm.h +++ b/src/symbol/graph_algorithm.h @@ -12,6 +12,7 @@ #include #include #include +#include "./static_graph.h" namespace mxnet { namespace graph { diff --git a/src/symbol/graph_executor.cc b/src/symbol/graph_executor.cc index cf70826919fa..bcab616668aa 100644 --- a/src/symbol/graph_executor.cc +++ b/src/symbol/graph_executor.cc @@ -239,9 +239,9 @@ GraphExecutor::~GraphExecutor() { Engine::Get()->WaitForAll(); } -void GraphExecutor::InitGraph(Symbol symbol, Context ctx, bool need_backward) { +void GraphExecutor::InitGraph(const Symbol &symbol, Context ctx, bool need_backward) { // initialize all internal data structures - symbol.ToStaticGraph(&graph_); + graph_.FromSymbol(symbol); num_forward_nodes_ = graph_.nodes.size(); if (need_backward) { graph_.MakeBackwardPass(&head_grad_nodes_, &arg_grads_); diff --git a/src/symbol/graph_executor.h b/src/symbol/graph_executor.h index a7dbb90892ee..010b700c62ac 100644 --- a/src/symbol/graph_executor.h +++ b/src/symbol/graph_executor.h @@ -11,6 +11,7 @@ #include #include #include +#include "./static_graph.h" #include "./graph_memory_allocator.h" namespace mxnet { @@ -159,7 +160,7 @@ class GraphExecutor : public Executor { */ inline OpExecEntry GetOpExecEntry(uint32_t node_id); // initialize the internal graph structure - void InitGraph(Symbol symbol, Context ctx, bool need_backward); + void InitGraph(const Symbol &symbol, Context ctx, bool need_backward); // initialize internal DataEntryInfo, reference counting void InitDataEntryInfo(const std::vector &in_args, const std::vector &arg_grad_store, diff --git a/src/symbol/graph_memory_allocator.h 
b/src/symbol/graph_memory_allocator.h index 2ed10f14fcb3..759c79ad5452 100644 --- a/src/symbol/graph_memory_allocator.h +++ b/src/symbol/graph_memory_allocator.h @@ -11,6 +11,7 @@ #include #include #include +#include "./static_graph.h" #include "./graph_algorithm.h" namespace mxnet { diff --git a/src/symbol/static_graph.cc b/src/symbol/static_graph.cc index 9afcfec67097..396324c6951e 100644 --- a/src/symbol/static_graph.cc +++ b/src/symbol/static_graph.cc @@ -8,6 +8,7 @@ #include #include #include +#include "./static_graph.h" #include "../operator/operator_common.h" namespace mxnet { diff --git a/src/symbol/static_graph.h b/src/symbol/static_graph.h new file mode 100644 index 000000000000..514a8f6d80a0 --- /dev/null +++ b/src/symbol/static_graph.h @@ -0,0 +1,242 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file static_graph.h + * \brief A memory compact representation of symbolic graph + * Used for serialization, and helper data structure. + * \author Naiyan Wang + */ +#ifndef MXNET_SYMBOL_STATIC_GRAPH_H_ +#define MXNET_SYMBOL_STATIC_GRAPH_H_ + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace mxnet { +/*! + * \brief StaticGraph is the configuration of computation graphs. + * This is the "configuration file" of mxnet. + * It can be converted to/from Symbol, and can be used to bind to operators. + * The symbol can be converted from/to StaticGraph, the actual configuration used by mxnet. + * Symbol offers more flexible way to composite nodes than StaticGraph, which makes it good + * tool to generate configurations from language bindings such as python. + * \sa Symbol + */ +class StaticGraph { + public: + /*! \brief represents a data in the graph */ + struct DataEntry { + /*! \brief the source node id in the computation graph */ + uint32_t source_id; + /*! \brief index of output from the source. */ + uint32_t index; + /*! \brief default constructor */ + DataEntry() {} + /*! 
+ * \brief constructor with source and index + * \param source_id source id + * \param index node index + */ + DataEntry(uint32_t source_id, uint32_t index) + : source_id(source_id), index(index) {} + /*! + * \brief compare equality + * \param other the other entry to compare + * \return whether two entries equals to each other + */ + inline bool operator==(const DataEntry &other) const { + return source_id == other.source_id && index == other.index; + } + /*! + * \brief comparator, allows to use map + * \param other the other entry to compare + * \return whether two entries is smaller than the other + */ + inline bool operator<(const DataEntry &other) const { + if (source_id == other.source_id) return index < other.index; + return source_id < other.source_id; + } + /*! + * \brief interface for json serialization. + * \param writer the JSON writer to write json into. + */ + inline void Save(dmlc::JSONWriter *writer) const { + writer->BeginArray(false); + writer->WriteArrayItem(source_id); + writer->WriteArrayItem(index); + writer->EndArray(); + } + /*! + * \brief interface for json serialization. + * \param reader the JSON reader to read json from. + */ + inline void Load(dmlc::JSONReader *reader) { + std::pair p; + reader->Read(&p); + *this = DataEntry(p.first, p.second); + } + }; + /*! + * \brief Operation Node in static graphs. + * There are two types of node, Forward and Backward Node. + * + * - Forward node corresponds to the op.Forward + * - Backward node corresponds to the Backward pass, + * where the corresponding forward node is indicated by backward_source_id. + * The op field in Backward node is nullptr + * + * The reason we explicit support Backward node is to allow special treatment + * such as shape inference and state sharing with Forward pass. + */ + struct Node { + /*! \brief wrapped operator property */ + std::unique_ptr op; + /*! \brief name of the node */ + std::string name; + /*! 
\brief inputs (node_id, index) of the node */ + std::vector inputs; + /*! + * \brief If this field is nonnegative, this indicates this + * Node corresponds to a Backward Operation of Operator. + * backward_source_id will point to the corresponding Forward Node. + * + * For normal node, this field is -1. + * When the node is a Backward node, the op field will be nullptr + */ + int32_t backward_source_id; + /*! \brief default constructor */ + Node() : backward_source_id(-1) {} + + friend void swap(Node& lhs, Node& rhs) { + std::swap(lhs.op, rhs.op); + std::swap(lhs.name, rhs.name); + std::swap(lhs.inputs, rhs.inputs); + std::swap(lhs.backward_source_id, rhs.backward_source_id); + } + /*! \brief copy constructor in favor of serialization. */ + Node(const Node& another) : op(another.op.get() ? another.op.get()->Copy() : nullptr), + name(another.name), + inputs(another.inputs), + backward_source_id(another.backward_source_id) {} + + inline Node& operator=(Node another) { + swap(*this, another); + return *this; + } + /*! \return whether the node is forward op node */ + inline bool is_forward() const { + return op != nullptr; + } + /*! \return whether the node is backward op node */ + inline bool is_backward() const { + return backward_source_id != -1; + } + /*! \return whether the node is variable node */ + inline bool is_variable() const { + return op == nullptr && !is_backward(); + } + /*! + * \brief interface for json serialization. + * \param writer the JSON writer write json. + */ + void Save(dmlc::JSONWriter *writer) const; + /*! + * \brief interface for json serialization. + * \param reader the JSON read to read json. + */ + void Load(dmlc::JSONReader *reader); + }; + /*! \brief all nodes in the graph */ + std::vector nodes; + /*! \brief index of nodes that correspond to arguments */ + std::vector arg_nodes; + /*! \brief heads outputs of the graph */ + std::vector heads; + /*! + * \brief interface for json serialization. 
+ * \param writer the JSON writer write json. + */ + void Save(dmlc::JSONWriter *writer) const; + /*! + * \brief interface for json serialization. + * \param reader the JSON read to read json. + */ + void Load(dmlc::JSONReader *reader); + // functions to help inference in static graph + /*! + * \brief Perform a topological sort on the graph + * \return a topological order of node indices. + */ + std::vector TopoSort() const; + /*! + * \brief infer the node shapes in the computation graph. + * + * When calling this function, user can setup the shape information known into right position. + * Unknown shape are indicated by shape.ndim() == 0. + * + * \param topo_order The topological order of node index, as created by TopoSort. + * \param node_out_shapes The shapes of the each outputs of nodes in the graph. + * \param node_aux_shapes The shapes of the each auxiliary states of nodes in the graph. + * \return if the shape inference is successful, return true, else return false. + */ + bool InferNodeShapes(const std::vector &topo_order, + std::vector > *node_out_shapes, + std::vector > *node_aux_shapes) const; + /*! + * \brief infer the shapes of outputs and unknown input arguments + * \param in_shape the shape of input arguments of the operator + * this should be of same length as the vector returned by ListArguments + * in_shape allows unknown elements, which are checked by shape.ndim() == 0. + * For unknown shapes, InferShape will try to fill in the correct Shape in in_shape + * For known shapes, InferShape will check shape consistency + * + * common practice: set the shape of data input, and usually weight's shape can be inferred + * + * \param out_shape the shape of outputs of the operator + * InferShape will modify the vector to fill output TShape + * \param aux_shape the shape of auxiliary states of the operator + * InferShape will modify the vector to fill output TShape + * \return if the shape inference is successful, return true, else return false. 
+ */ + bool InferShape(std::vector* in_shape, + std::vector* out_shape, + std::vector* aux_shape) const; + /*! + * \brief Add a full backward pass in the static graph. + * This function will add gradient nodes for each heads, + * and add the backward pass to backprop the gradients all + * the way to the arguments. + * + * This will change the nodes field in the StaticGraph, but will not change other fields. + * The head and input of Backward pass will be returned by head_grad_nodes and arg_grads. + * + * \param head_grad_nodes used to store the created head gradient inputs for backward pass. + * \param arg_grads used to store gradients to args, can be multiple one if an argument is used by operator + */ + void MakeBackwardPass(std::vector *head_grad_nodes, + std::vector *arg_grads); + /*! + * \brief Convert symbol into static graph. + * \param symbol the symbol to convert from. + */ + inline void FromSymbol(const Symbol &symbol) { + symbol.ToStaticGraph(this); + } + /*! + * \brief create a sum node that aggregates gradient together + * \param grad_source the source of the inputs. + * \return a created ElementWiseSum node + */ + static Node CreateSumNode(const std::vector &grad_source); +}; +} // namespace mxnet + +namespace dmlc { +DMLC_DECLARE_TRAITS(is_pod, ::mxnet::StaticGraph::DataEntry, true); +} +#endif // MXNET_SYMBOL_STATIC_GRAPH_H_ diff --git a/src/symbol/symbol.cc b/src/symbol/symbol.cc index 972802a6777d..2b923cebc0d9 100644 --- a/src/symbol/symbol.cc +++ b/src/symbol/symbol.cc @@ -9,6 +9,7 @@ #include #include #include +#include "./static_graph.h" namespace mxnet { /*! 
@@ -481,6 +482,19 @@ bool Symbol::InferShape(const std::unordered_map& known_arg return g.InferShape(arg_shapes, out_shapes, aux_shapes); } + +void Symbol::Save(dmlc::JSONWriter *writer) const { + StaticGraph g; + this->ToStaticGraph(&g); + g.Save(writer); +} + +void Symbol::Load(dmlc::JSONReader *reader) { + StaticGraph g; + g.Load(reader); + this->FromStaticGraph(g); +} + Symbol Symbol::Create(OperatorProperty *op) { // use special representation for atomic symbol auto node = std::make_shared(op, "");