
Commit 3626c34: refactor as discussed
antinucleon committed Aug 15, 2015
1 parent b1127a7 commit 3626c34
Showing 25 changed files with 579 additions and 931 deletions.
12 changes: 4 additions & 8 deletions Makefile
@@ -58,14 +58,14 @@ endif
BIN = test/api_registry_test
OBJ = storage.o narray_op_cpu.o
# add threaded engine after it is done
OBJCXX11 = engine.o narray.o c_api.o registry.o symbol.o fully_connect_op_cpu.o
OBJCXX11 = engine.o narray.o c_api.o registry.o symbol.o fully_connected_cpu.o static_graph.o
CUOBJ =
SLIB = lib/libmxnet.so
ALIB = lib/libmxnet.a
LIB_DEP = $(DMLC_CORE)/libdmlc.a

ifeq ($(USE_CUDA), 1)
CUOBJ += narray_op_gpu.o fully_connect_op_gpu.o
CUOBJ += narray_op_gpu.o fully_connected_gpu.o
endif

.PHONY: clean all test lint doc
@@ -77,20 +77,16 @@ $(DMLC_CORE)/libdmlc.a:

storage.o: src/storage/storage.cc
engine.o: src/dag_engine/simple_engine.cc
#engine.o: src/dag_engine/threaded_engine.cc src/common/concurrent_blocking_queue.h src/common/spin_lock.h
narray.o: src/narray/narray.cc
narray_op_cpu.o: src/narray/narray_op_cpu.cc src/narray/narray_op-inl.h
narray_op_gpu.o: src/narray/narray_op_gpu.cu src/narray/narray_op-inl.h
static_operator.o: src/operator/static_operator/static_operator.cc
static_operator_cpu.o: src/operator/static_operator/static_operator_cpu.cc
static_operator_gpu.o: src/operator/static_operator/static_operator_gpu.cu
symbol.o: src/symbol/symbol.cc
static_graph.o : src/symbol/static_graph.cc
registry.o: src/registry.cc
c_api.o: src/c_api.cc
operator.o: src/operator/static_operator_wrapper.cc
fully_connect_op_cpu.o: src/operator/static_operator/fully_connect_op.cc
fully_connect_op_gpu.o: src/operator/static_operator/fully_connect_op.cu
fully_connected_cpu.o: src/operator/fully_connected.cc
fully_connected_gpu.o: src/operator/fully_connected.cu


lib/libmxnet.a: $(OBJ) $(OBJCXX11) $(CUOBJ)
12 changes: 6 additions & 6 deletions include/mxnet/c_api.h
@@ -225,19 +225,19 @@ MXNET_DLL int MXSymbolListAtomicSymbolCreators(mx_uint *out_size,
MXNET_DLL int MXSymbolGetAtomicSymbolName(AtomicSymbolCreator creator,
const char **out);
/*!
* \brief create Symbol by wrapping AtomicSymbol
* \brief Create an AtomicSymbol.
* \param creator the AtomicSymbolCreator
* \param num_param the number of parameters
* \param keys the keys to the params
* \param vals the vals of the params
* \param out pointer to the created symbol handle
* \return 0 on success, -1 on failure
*/
MXNET_DLL int MXSymbolCreateFromAtomicSymbol(AtomicSymbolCreator creator,
int num_param,
const char **keys,
const char **vals,
SymbolHandle *out);
MXNET_DLL int MXSymbolCreateAtomicSymbol(AtomicSymbolCreator creator,
int num_param,
const char **keys,
const char **vals,
SymbolHandle *out);
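A usage sketch of the renamed entry point. This is illustrative only: it assumes MXSymbolListAtomicSymbolCreators returns the creator array through a second out parameter (truncated in the hunk above), and that a FullyConnected creator accepting a num_hidden key exists; neither is guaranteed by this diff.

#include <mxnet/c_api.h>
#include <cstdio>
#include <cstring>

int main() {
  mx_uint n = 0;
  AtomicSymbolCreator *creators = nullptr;  // assumed out parameter, see note above
  if (MXSymbolListAtomicSymbolCreators(&n, &creators) != 0) return 1;
  for (mx_uint i = 0; i < n; ++i) {
    const char *name = nullptr;
    MXSymbolGetAtomicSymbolName(creators[i], &name);
    if (std::strcmp(name, "FullyConnected") != 0) continue;
    // Parameters travel as parallel key/value string arrays.
    const char *keys[] = {"num_hidden"};
    const char *vals[] = {"128"};
    SymbolHandle out = nullptr;
    if (MXSymbolCreateAtomicSymbol(creators[i], 1, keys, vals, &out) == 0)
      std::printf("created atomic symbol %s\n", name);
  }
  return 0;
}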
/*!
* \brief Create a Variable Symbol.
* \param name name of the variable
1 change: 1 addition & 0 deletions include/mxnet/narray.h
@@ -5,6 +5,7 @@
*/
#ifndef MXNET_NARRAY_H_
#define MXNET_NARRAY_H_

#include <dmlc/base.h>
#include <dmlc/logging.h>
#include <memory>
252 changes: 199 additions & 53 deletions include/mxnet/operator.h
@@ -1,20 +1,17 @@
/*!
* Copyright (c) 2015 by Contributors
* \file operator.h
* \brief operator interface of mxnet
* \brief Operator interface of mxnet.
* \author Naiyan Wang
*/
#ifndef MXNET_OPERATOR_H_
#define MXNET_OPERATOR_H_
// this file will be seen by cuda, no c++11 for now

#include <dmlc/base.h>
#include <vector>
#include <string>
#include <utility>
#include "./base.h"
#if DMLC_USE_CXX11
#include "./narray.h"
#include "./dag_engine.h"
#endif
#include "./symbolic.h"

namespace mxnet {
/*! \brief option to pass into the forward function */
@@ -38,21 +35,25 @@ enum OpReqType {
/*! \brief add to the provided space */
kAddTo
};
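The request types tell an operator how to treat caller-provided output memory. Below is a minimal sketch of the dispatch; this is not mxnet code: the enum is abridged to the three values this header's comments mention, and Assign is a hypothetical helper.

#include <vector>

enum OpReqType { kNullOp, kWriteTo, kAddTo };  // abridged for illustration

void Assign(std::vector<float> *dst, OpReqType req,
            const std::vector<float> &src) {
  switch (req) {
    case kNullOp:   // the caller did not request this result
      break;
    case kWriteTo:  // dst is pre-allocated and may be overwritten
      *dst = src;
      break;
    case kAddTo:    // accumulate into whatever dst already holds
      for (size_t i = 0; i < src.size(); ++i) (*dst)[i] += src[i];
      break;
  }
}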

/*!
* \brief StaticOperator interface
* StaticOperator is a stateful object that can be used to call forward and backprop
*
* \brief Operator interface.
* Operator defines the basic operation unit of the optimized computation graph in mxnet.
* This interface relies on pre-allocated memory in TBlob; the caller needs to set
* the memory region in TBlob correctly before calling Forward and Backward.
*
* Operator is generated by OperatorProperty.
* To add a new operator (aka a layer of a neural net) to mxnet, a developer needs
* to create a new OperatorProperty and its corresponding Operator.
*
* \sa TBlob, TShape
* \sa TBlob, TShape, OperatorProperty
*/
class StaticOperator {
class Operator {
public:
/*! \brief destructor */
virtual ~StaticOperator() {}
virtual ~Operator() {}
/*!
* \brief perform a forward operation of StaticOperator, save the output to TBlob.
* \brief perform a forward operation of Operator, save the output to TBlob.
* \param opt option for Forward, such as whether this is the training phase
* \param ctx runtime context
* \param in_data array of input data, it is const
@@ -69,7 +70,7 @@ class StaticOperator {
/*!
* \brief Perform a backward Operation, write gradient to the in_grad.
* \param ctx runtime context
* \param out_grad the gradient value we get from output of the StaticOperator
* \param out_grad the gradient value we get from output of the Operator
* \param in_data the array of input data.
* \param out_data the array of output data.
* \param req request types of the saving operation, can be all types.
@@ -85,53 +86,198 @@
const std::vector<TBlob> &out_grad,
const std::vector<TBlob> &in_data,
const std::vector<TBlob> &out_data,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &in_grad) = 0;
};

#if DMLC_USE_CXX11
// OperatorProperty allows C++11, while Operator does not rely on it.
/*!
* \brief Operator interface.
* Operator is an object can have Forward and Backward function.
* \brief OperatorProperty is an object that stores all information about Operator.
* It also contains methods to generate context (device) specific operators.
*
* It also contains various functions that can be optionally overridden to
* provide optimization chances for the computation engine.
*/
class Operator {
class OperatorProperty {
public:
/*! \brief destructor */
virtual ~Operator() {}
/*!
* \brief Perform a Forward operation of Operator
* After this operation, user can get the result by using function head.
* \brief virtual destructor
*/
virtual ~OperatorProperty() {}
/*!
* \brief Get input arguments of the Operator.
* \return vector of arguments.
*/
virtual std::vector<std::string> ListArguments() const {
return {"data"};
}
/*!
* \brief Get the names of the return values of Operator.
* \return names of the return values.
*/
virtual void Forward() = 0;
virtual std::vector<std::string> ListReturns() const {
return {"output"};
}
/*! \return number of outputs of the Operator */
virtual int NumReturns() const {
return 1;
}
/*!
* \brief Perform a Backward operation of the Operator.
* This must be called after Forward.
* After this operation, NArrays specified by grad_in_args_store will be updated accordingly.
* \brief Set the parameters of the Operator.
* \param name parameter name
* \param val string for the configuration
*/
virtual void Backward() = 0;
/*! \return get array of heads in the operator */
virtual const std::vector<NArray> &head() const = 0;
virtual void SetParam(const char *name, const char *val) {}
/*!
* \brief Create an operator by bind symbol with context and arguments.
* If user do not want to compute the gradients of i-th argument, grad_req_type[i] can be kNullOp.
* \brief infer the shapes of outputs and unknown input arguments
* \param in_shape the shape of input arguments of the operator
* this should have the same length as the vector returned by ListArguments
* in_shape allows unknown elements, which are checked by shape.ndim() == 0.
* For unknown shapes, InferShape will try to fill in the correct Shape in in_shape
* For known shapes, InferShape will check shape consistency
*
* \param ctx the context of binding.
* \param symbol the symbol that specifies the output of Forward pass.
* \param in_args the NArray that stores the input arguments to the symbol.
* \param grad_in_args_store NArray that is used to store the gradient output of the input arguments.
* \param grad_req_type requirment type of gradient saving. Can only be in {kNullOp, kAddTo, kWriteTo}.
*/
static Operator *Bind(Symbol symbol,
Context ctx,
const std::vector<NArray> &in_args,
const std::vector<NArray> &grad_in_args_store,
const std::vector<OpReqType> &grad_req_type);
}; // class operator
* common practice: set the shape of the data input; usually the weight's shape can be inferred
*
* \param out_shape the shape of outputs of the operator
* InferShape will modify the vector to fill output TShape
* \return if the shape inference is successful, return true, else return false.
*/
virtual bool InferShape(std::vector<TShape> *in_shape,
std::vector<TShape> *out_shape) const = 0;
/*!
* \brief Copy this OperatorProperty.
* \return a pointer to the copied OperatorProperty
*/
virtual OperatorProperty* Copy() const = 0;
/*!
* \brief Create an Operator on a specific context
*/
virtual Operator* CreateOperator(Context ctx) const = 0;
/*!
* \brief return the type string of the Operator
* subclasses override this function.
*/
virtual std::string TypeString() const = 0;
/*!
* \brief Declare the input requirement of Backward pass.
*
* Only the returned list of variables will be used in Backward.
* This function is used for memory optimization.
* It is advised to override it and only return what is actually needed.
* If this function is not overridden, all the variables will be valid in Backward.
*
* \code
* // The following code declares that Backward needs out_grad[0], in_data[0], in_data[1]
* vector<int> BackwardInputs(const vector<int> &out_grad,
* const vector<int> &in_data,
* const vector<int> &out_data) const {
* return {out_grad[0], in_data[0], in_data[1]};
* }
* \endcode
* \param out_grad gradient of outputs in backward pass.
* \param in_data the input data in forward pass.
* \param out_data the output data in forward pass.
* \return an integer vector indicating the input requirements
* \sa BackwardInputs
*/
virtual std::vector<int> DeclareBackwardDependency(
const std::vector<int> &out_grad,
const std::vector<int> &in_data,
const std::vector<int> &out_data) const {
// By default this requires all of out_grad, in_data and out_data.
// Remember to override this function to get better performance.
std::vector<int> ret = out_grad;
ret.insert(ret.end(), in_data.begin(), in_data.end());
ret.insert(ret.end(), out_data.begin(), out_data.end());
return ret;
}
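For example, an elementwise operator whose gradient needs only the forward output and the incoming gradient could trim the list to two entries. A hypothetical override inside such a subclass, assuming a single input and a single output:

std::vector<int> DeclareBackwardDependency(
    const std::vector<int> &out_grad,
    const std::vector<int> &in_data,
    const std::vector<int> &out_data) const override {
  // in_data is deliberately omitted, so the engine may free or reuse it.
  return {out_grad[0], out_data[0]};
}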
/*!
* \brief Get possible forward inplace options.
* This function enables optimization that reuses the memory of inputs for outputs.
* Only override when necessary; by default in-place is disabled.
*
* \code
* // The following code says out_data[0] can share data with in_data[0]
* vector<pair<int,int> > ForwardInplaceOption(const vector<int> &in_data,
* const vector<int> &out_data) const {
* return {{out_data[0], in_data[0]}};
* }
* \endcode
* \return a list of pairs of integers taken from the inputs vector,
* indicating possible in place operations.
*/
virtual std::vector<std::pair<int, int> > ForwardInplaceOption(
const std::vector<int> &in_data,
const std::vector<int> &out_data) const {
return std::vector<std::pair<int, int> >();
}
/*!
* \brief Get possible backward inplace options.
* This function enables optimization that reuses the memory of inputs for outputs.
* Only override when necessary; by default in-place is disabled.
*
* \code
* // The following code says in_grad[0] can share data with in_data[0]
* vector<pair<int,int> > BackwardInplaceOption(
* const std::vector<int> &out_grad,
* const std::vector<int> &in_data,
* const std::vector<int> &out_data,
* const std::vector<int> &in_grad) const {
* return {{in_grad[0], in_data[0]}};
* }
* \endcode
* \return a list of pairs of integers taken from the inputs vector,
* indicating possible in place operations.
*/
virtual std::vector<std::pair<int, int> > BackwardInplaceOption(
const std::vector<int> &out_grad,
const std::vector<int> &in_data,
const std::vector<int> &out_data,
const std::vector<int> &in_grad) const {
return std::vector<std::pair<int, int> >();
}
/*!
* \brief Get Backward Input Dependency for generic types of data.
* Normally T can be a pointer to Symbol::DataEntry, or an NArray.
* This function will select the result list of T according to DeclareBackwardDependency.
*
* \param in_data the input data in forward pass.
* \param out_data the output data in forward pass.
* \param out_grad gradient of outputs in backward pass.
* \tparam T the generic type parameter.
* \return vector of inputs the Backward Operation depends on.
* \sa DeclareBackwardDependency
*/
template<typename T>
inline std::vector<T> BackwardInputs(const std::vector<T> &in_data,
const std::vector<T> &out_data,
const std::vector<T> &out_grad) const {
int cnt = 0;
std::vector<T> all_vec;
std::vector<int> in_data_idx, out_data_idx, out_grad_idx;
for (size_t i = 0; i < in_data.size(); ++i) {
in_data_idx.push_back(cnt++);
all_vec.push_back(in_data[i]);
}
for (size_t i = 0; i < out_data.size(); ++i) {
out_data_idx.push_back(cnt++);
all_vec.push_back(out_data[i]);
}
for (size_t i = 0; i < out_grad.size(); ++i) {
out_grad_idx.push_back(cnt++);
all_vec.push_back(out_grad[i]);
}
// DeclareBackwardDependency expects (out_grad, in_data, out_data) in that order.
std::vector<int> ret_idx = this->DeclareBackwardDependency(
out_grad_idx, in_data_idx, out_data_idx);
std::vector<T> ret;
for (size_t i = 0; i < ret_idx.size(); ++i) {
ret.push_back(all_vec[ret_idx[i]]);
}
return ret;
}
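A hypothetical call site with T deduced as NArray; prop and the three vectors are illustrative names, assumed to follow the forward-pass ordering of the parameters above.

std::vector<NArray> needed =
    prop->BackwardInputs(in_data_arrays, out_data_arrays, out_grad_arrays);
// needed now holds exactly the entries DeclareBackwardDependency asked for,
// so the engine can free or reuse everything else.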
/*!
* \brief create OperatorProperty
* \param type_name the type string of the OperatorProperty
* \return a new constructed OperatorProperty
*/
static OperatorProperty *Create(const char* type_name);
};
#endif
} // namespace mxnet
#endif // MXNET_OPERATOR_H_
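Putting the pieces together, here is a minimal sketch of a custom property written against the interface above. IdentityProp is hypothetical and not part of this commit; CreateOperator is stubbed because the Operator signatures are truncated in this diff.

#include <mxnet/operator.h>

class IdentityProp : public mxnet::OperatorProperty {
 public:
  bool InferShape(std::vector<mxnet::TShape> *in_shape,
                  std::vector<mxnet::TShape> *out_shape) const override {
    if ((*in_shape)[0].ndim() == 0) return false;  // input shape still unknown
    out_shape->clear();
    out_shape->push_back((*in_shape)[0]);          // output mirrors the input
    return true;
  }
  mxnet::OperatorProperty* Copy() const override {
    return new IdentityProp(*this);
  }
  mxnet::Operator* CreateOperator(mxnet::Context ctx) const override {
    // A real property constructs its device-specific Operator here; stubbed
    // because the Operator interface is only partially visible in this diff.
    return nullptr;
  }
  std::string TypeString() const override { return "Identity"; }
  std::vector<std::pair<int, int> > ForwardInplaceOption(
      const std::vector<int> &in_data,
      const std::vector<int> &out_data) const override {
    return {{out_data[0], in_data[0]}};  // output may reuse the input buffer
  }
};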
