This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

refactor as discussed #15

Merged 1 commit on Aug 15, 2015
Makefile (12 changes: 4 additions & 8 deletions)
@@ -58,14 +58,14 @@ endif
BIN = test/api_registry_test
OBJ = storage.o narray_op_cpu.o
# add threaded engine after it is done
OBJCXX11 = engine.o narray.o c_api.o registry.o symbol.o fully_connect_op_cpu.o
OBJCXX11 = engine.o narray.o c_api.o registry.o symbol.o fully_connected_cpu.o static_graph.o
CUOBJ =
SLIB = lib/libmxnet.so
ALIB = lib/libmxnet.a
LIB_DEP = $(DMLC_CORE)/libdmlc.a

ifeq ($(USE_CUDA), 1)
CUOBJ += narray_op_gpu.o fully_connect_op_gpu.o
CUOBJ += narray_op_gpu.o fully_connected_gpu.o
endif

.PHONY: clean all test lint doc
@@ -77,20 +77,16 @@ $(DMLC_CORE)/libdmlc.a:

storage.o: src/storage/storage.cc
engine.o: src/dag_engine/simple_engine.cc
#engine.o: src/dag_engine/threaded_engine.cc src/common/concurrent_blocking_queue.h src/common/spin_lock.h
narray.o: src/narray/narray.cc
narray_op_cpu.o: src/narray/narray_op_cpu.cc src/narray/narray_op-inl.h
narray_op_gpu.o: src/narray/narray_op_gpu.cu src/narray/narray_op-inl.h
static_operator.o: src/operator/static_operator/static_operator.cc
static_operator_cpu.o: src/operator/static_operator/static_operator_cpu.cc
static_operator_gpu.o: src/operator/static_operator/static_operator_gpu.cu
symbol.o: src/symbol/symbol.cc
static_graph.o: src/symbol/static_graph.cc
registry.o: src/registry.cc
c_api.o: src/c_api.cc
operator.o: src/operator/static_operator_wrapper.cc
fully_connect_op_cpu.o: src/operator/static_operator/fully_connect_op.cc
fully_connect_op_gpu.o: src/operator/static_operator/fully_connect_op.cu
fully_connected_cpu.o: src/operator/fully_connected.cc
fully_connected_gpu.o: src/operator/fully_connected.cu


lib/libmxnet.a: $(OBJ) $(OBJCXX11) $(CUOBJ)
include/mxnet/c_api.h (12 changes: 6 additions & 6 deletions)
@@ -225,19 +225,19 @@ MXNET_DLL int MXSymbolListAtomicSymbolCreators(mx_uint *out_size,
MXNET_DLL int MXSymbolGetAtomicSymbolName(AtomicSymbolCreator creator,
const char **out);
/*!
* \brief create Symbol by wrapping AtomicSymbol
* \brief Create an AtomicSymbol.
* \param creator the AtomicSymbolCreator
* \param num_param the number of parameters
* \param keys the keys to the params
* \param vals the vals of the params
* \param out pointer to the created symbol handle
* \return 0 on success, -1 on failure
*/
MXNET_DLL int MXSymbolCreateFromAtomicSymbol(AtomicSymbolCreator creator,
int num_param,
const char **keys,
const char **vals,
SymbolHandle *out);
MXNET_DLL int MXSymbolCreateAtomicSymbol(AtomicSymbolCreator creator,
int num_param,
const char **keys,
const char **vals,
SymbolHandle *out);
/*!
* \brief Create a Variable Symbol.
* \param name name of the variable
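For orientation, here is a minimal sketch of how a caller might exercise the renamed entry point, using only the C API calls visible in this hunk. It assumes MXSymbolListAtomicSymbolCreators fills an output array of creators, and the "num_hidden" key is a hypothetical parameter name, not one taken from this diff:

#include <cstdio>
#include <mxnet/c_api.h>

int main() {
  mx_uint num_creators = 0;
  AtomicSymbolCreator *creators = nullptr;
  // Enumerate the registered atomic symbol creators (assumed output-array signature).
  if (MXSymbolListAtomicSymbolCreators(&num_creators, &creators) != 0) return -1;
  for (mx_uint i = 0; i < num_creators; ++i) {
    const char *name = nullptr;
    MXSymbolGetAtomicSymbolName(creators[i], &name);
    std::printf("creator %u: %s\n", static_cast<unsigned>(i), name);
  }
  // Hypothetical parameter key/value; real keys depend on the operator.
  const char *keys[] = {"num_hidden"};
  const char *vals[] = {"128"};
  SymbolHandle sym = nullptr;
  if (num_creators > 0 &&
      MXSymbolCreateAtomicSymbol(creators[0], 1, keys, vals, &sym) != 0) {
    return -1;  // the API returns 0 on success, -1 on failure
  }
  return 0;
}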
include/mxnet/narray.h (1 change: 1 addition & 0 deletions)
@@ -5,6 +5,7 @@
*/
#ifndef MXNET_NARRAY_H_
#define MXNET_NARRAY_H_

#include <dmlc/base.h>
#include <dmlc/logging.h>
#include <memory>
include/mxnet/operator.h (252 changes: 199 additions & 53 deletions)
@@ -1,20 +1,17 @@
/*!
* Copyright (c) 2015 by Contributors
* \file operator.h
* \brief operator interface of mxnet
* \brief Operator interface of mxnet.
* \author Naiyan Wang
*/
#ifndef MXNET_OPERATOR_H_
#define MXNET_OPERATOR_H_
// this file will be seen by cuda, no c++11 for now

#include <dmlc/base.h>
#include <vector>
#include <string>
#include <utility>
#include "./base.h"
#if DMLC_USE_CXX11
#include "./narray.h"
#include "./dag_engine.h"
#endif
#include "./symbolic.h"

namespace mxnet {
/*! \brief option to pass into the forward function */
@@ -38,21 +38,25 @@ enum OpReqType {
/*! \brief add to the provided space */
kAddTo
};

/*!
* \brief StaticOperator interface
* StaticOperator is a stateful object that can be used to call forward and backprop
*
* \brief Operator interface.
* Operator defines the basic operation unit of the optimized computation graph in mxnet.
* This interface relies on pre-allocated memory in TBlob; the caller needs to set
* the memory region in TBlob correctly before calling Forward and Backward
* the memory region in TBlob correctly before calling Forward and Backward.
*
* Operator is generated by OperatorProperty.
* To add a new operator (a.k.a. a layer of a neural net) to mxnet, developers need to create
* a new OperatorProperty and its corresponding Operator.
*
* \sa TBlob, TShape
* \sa TBlob, TShape, OperatorProperty
*/
class StaticOperator {
class Operator {
public:
/*! \brief destructor */
virtual ~StaticOperator() {}
virtual ~Operator() {}
/*!
* \brief perform a forward operation of StaticOperator, save the output to TBlob.
* \brief perform a forward operation of Operator, save the output to TBlob.
* \param opt option on Forward such as whether this is training phase.
* \param ctx runtime context
* \param in_data array of input data, it is const
@@ -69,7 +69,7 @@ class StaticOperator {
/*!
* \brief Perform a backward Operation, write gradient to the in_grad.
* \param ctx runtime context
* \param out_grad the gradient value we get from output of the StaticOperator
* \param out_grad the gradient value we get from output of the Operator
* \param in_data the array of input data.
* \param out_data the array of output data.
* \param req request types of the saving operation, can be all types.
@@ -85,53 +85,198 @@ class StaticOperator {
};

#if DMLC_USE_CXX11
const std::vector<TBlob> &out_grad,
const std::vector<TBlob> &in_data,
const std::vector<TBlob> &out_data,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &in_grad) = 0;
};

#if DMLC_USE_CXX11
// OperatorProperty allows C++11, while Operator does not rely on it.
/*!
* \brief Operator interface.
* Operator is an object can have Forward and Backward function.
* \brief OperatorProperty is an object that stores all information about an Operator.
* It also contains methods to generate context (device) specific operators.
*
* It can be created from
* It also contains various functions that can be optionally overridden to
* give the computation engine opportunities for optimization.
*/
class Operator {
class OperatorProperty {
public:
/*! \brief destructor */
virtual ~Operator() {}
/*!
* \brief Perform a Forward operation of Operator
* After this operation, user can get the result by using function head.
* \brief virtual destructor
*/
virtual ~OperatorProperty() {}
/*!
* \brief Get input arguments of the Operator.
* \return vector of arguments.
*/
virtual std::vector<std::string> ListArguments() const {
return {"data"};
}
/*!
* \brief Get the names of the return values of the Operator.
* \return names of the return values.
*/
virtual void Forward() = 0;
virtual std::vector<std::string> ListReturns() const {
return {"output"};
}
/*! \return number of outputs of the Operator */
virtual int NumReturns() const {
return 1;
}
/*!
* \brief Perform a Backward operation of the Operator.
* This must be called after Forward.
* After this operation, NArrays specified by grad_in_args_store will be updated accordingly.
* \brief Set the parameters of the Operator.
* \param name parameter name
* \param val string for the configuration
*/
virtual void Backward() = 0;
/*! \return get array of heads in the operator */
virtual const std::vector<NArray> &head() const = 0;
virtual void SetParam(const char *name, const char *val) {}
/*!
* \brief Create an operator by bind symbol with context and arguments.
* If user do not want to compute the gradients of i-th argument, grad_req_type[i] can be kNullOp.
* \brief infer the shapes of outputs and unknown input arguments
* \param in_shape the shape of input arguments of the operator
* this should be of the same length as the vector returned by ListArguments
* in_shape allows unknown elements, which are checked by shape.ndim() == 0.
* For unknown shapes, InferShape will try to fill in the correct Shape in in_shape
* For known shapes, InferShape will check shape consistency
*
* \param ctx the context of binding.
* \param symbol the symbol that specifies the output of Forward pass.
* \param in_args the NArray that stores the input arguments to the symbol.
* \param grad_in_args_store NArray that is used to store the gradient output of the input arguments.
* \param grad_req_type requirment type of gradient saving. Can only be in {kNullOp, kAddTo, kWriteTo}.
*/
static Operator *Bind(Symbol symbol,
Context ctx,
const std::vector<NArray> &in_args,
const std::vector<NArray> &grad_in_args_store,
const std::vector<OpReqType> &grad_req_type);
}; // class operator
* Common practice: set the shape of the data input; the weight's shape can usually be inferred.
*
* \param out_shape the shape of outputs of the operator
* InferShape will modify the vector to fill output TShape
* \return true if shape inference is successful, false otherwise.
*/
virtual bool InferShape(std::vector<TShape> *in_shape,
std::vector<TShape> *out_shape) const = 0;
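As a concrete illustration of this contract, a hedged sketch of what an InferShape override might look like for the fully-connected operator introduced by this PR; the {data, weight, bias} layout, the num_hidden_ field, and the mshadow::Shape2/Shape1 conversions are assumptions, not code from src/operator/fully_connected.cc:

// Hypothetical sketch of an override inside some FullyConnectedProp subclass.
bool InferShape(std::vector<TShape> *in_shape,
                std::vector<TShape> *out_shape) const override {
  CHECK_EQ(in_shape->size(), 3);           // expects {data, weight, bias}
  const TShape &dshape = (*in_shape)[0];
  if (dshape.ndim() == 0) return false;    // data shape unknown: cannot infer yet
  mshadow::index_t batch = dshape[0];
  mshadow::index_t num_input = dshape[1];
  (*in_shape)[1] = mshadow::Shape2(num_hidden_, num_input);   // fill weight shape
  (*in_shape)[2] = mshadow::Shape1(num_hidden_);              // fill bias shape
  out_shape->clear();
  out_shape->push_back(mshadow::Shape2(batch, num_hidden_));  // output shape
  return true;
}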
/*!
* \brief Copy this OperatorProperty.
* \return a pointer to the copied OperatorProperty
*/
virtual OperatorProperty* Copy() const = 0;
/*!
* \brief Create an Operator on a specific context
*/
virtual Operator* CreateOperator(Context ctx) const = 0;
/*!
* \brief return the type string of the Operator.
* Subclasses should override this function.
*/
virtual std::string TypeString() const = 0;
/*!
* \brief Declare the input requirements of the Backward pass.
*
* Only the returned list of variables will be used in Backward.
* This function is used for memory optimization.
* It is advised to override it and return only what is actually needed.
* If this function is not overridden, all the variables will be valid in Backward.
*
* \code
* // The following code declares that Backward needs out_grad[0], in_data[0], in_data[1]
* vector<int> BackwardInputs(const vector<int> &out_grad,
* const vector<int> &in_data,
* const vector<int> &out_data) const {
* return {out_grad[0], in_data[0], in_data[1]};
* }
* \endcode
* \param out_grad gradient of outputs in backward pass.
* \param in_data the input data in forward pass.
* \param out_data the output data in forward pass.
* \return an integer vector indicating the input requirements
* \sa BackwardInputs
*/
virtual std::vector<int> DeclareBackwardDependency(
const std::vector<int> &out_grad,
const std::vector<int> &in_data,
const std::vector<int> &out_data) const {
// By default, require all of out_grad, in_data, and out_data.
// Remember to override this function for better performance.
std::vector<int> ret = out_grad;
ret.insert(ret.end(), in_data.begin(), in_data.end());
ret.insert(ret.end(), out_data.begin(), out_data.end());
return ret;
}
/*!
* \brief Get possible forward inplace options.
* This function enables optimizations that reuse the memory of inputs for outputs.
* Only override when necessary; by default, in-place is disabled.
*
* \code
* // The following code says out_data[0] can share data with in_data[0]
* vector<pair<int,int> > ForwardInplaceOption(const vector<int> &in_data,
* const vector<int> &out_data) const {
* return {{out_data[0], in_data[0]}};
* }
* \endcode
* \return a list of pairs of integers taken from the input vectors,
* indicating possible in-place operations.
*/
virtual std::vector<std::pair<int, int> > ForwardInplaceOption(
const std::vector<int> &in_data,
const std::vector<int> &out_data) const {
return std::vector<std::pair<int, int> >();
}
/*!
* \brief Get possible backward inplace options.
* This function enables optimizations that reuse the memory of inputs for outputs.
* Only override when necessary; by default, in-place is disabled.
*
* \code
* // The following code says in_grad[0] can share data with in_data[0]
* vector<pair<int,int> > BackwardInplaceOption(
* const std::vector<int> &out_grad,
* const std::vector<int> &in_data,
* const std::vector<int> &out_data,
* const std::vector<int> &in_grad) const {
* return {{in_grad[0], in_data[0]}};
* }
* \endcode
* \return a list of pairs of integers taken from the input vectors,
* indicating possible in-place operations.
*/
virtual std::vector<std::pair<int, int> > BackwardInplaceOption(
const std::vector<int> &out_grad,
const std::vector<int> &in_data,
const std::vector<int> &out_data,
const std::vector<int> &in_grad) const {
return std::vector<std::pair<int, int> >();
}
/*!
* \brief Get Backward input dependencies for generic types of data.
* Normally T can be a pointer to Symbol::DataEntry, or an NArray.
* This function selects the subset of T according to DeclareBackwardDependency.
*
* \param in_data the input data in forward pass.
* \param out_data the output data in forward pass.
* \param out_grad gradient of outputs in backward pass.
* \tparam T the generic type parameter.
* \return vector of inputs the Backward Operation depends on.
* \sa DeclareBackwardDependency
*/
template<typename T>
inline std::vector<T> BackwardInputs(const std::vector<T> &in_data,
const std::vector<T> &out_data,
const std::vector<T> &out_grad) const {
int cnt = 0;
std::vector<T> all_vec;
std::vector<int> in_data_idx, out_data_idx, out_grad_idx;
for (size_t i = 0; i < in_data.size(); ++i) {
in_data_idx.push_back(cnt++);
all_vec.push_back(in_data[i]);
}
for (size_t i = 0; i < out_data.size(); ++i) {
out_data_idx.push_back(cnt++);
all_vec.push_back(out_data[i]);
}
for (size_t i = 0; i < out_grad.size(); ++i) {
out_grad_idx.push_back(cnt++);
all_vec.push_back(out_grad[i]);
}
std::vector<int> ret_idx = this->DeclareBackwardDependency(
out_grad_idx, in_data_idx, out_data_idx);
std::vector<T> ret;
for (size_t i = 0; i < ret_idx.size(); ++i) {
ret.push_back(all_vec[ret_idx[i]]);
}
return ret;
}
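A short usage sketch may help here. Because BackwardInputs is a template, plain ints can stand in for T; prop is assumed to be any constructed OperatorProperty that keeps the default DeclareBackwardDependency:

std::vector<int> in_data = {10, 11};  // stand-ins for two forward inputs
std::vector<int> out_data = {20};     // stand-in for the forward output
std::vector<int> out_grad = {30};     // stand-in for the output gradient
// With the default dependency declaration, the result is the concatenation
// out_grad ++ in_data ++ out_data, i.e. {30, 10, 11, 20}.
std::vector<int> needed = prop.BackwardInputs(in_data, out_data, out_grad);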
/*!
* \brief create OperatorProperty
* \param type_name the type string of the OperatorProperty
* \return a newly constructed OperatorProperty
*/
static OperatorProperty *Create(const char* type_name);
};
#endif
} // namespace mxnet
#endif // MXNET_OPERATOR_H_
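To tie the pieces together, a hedged sketch of a minimal OperatorProperty subclass; the IdentityProp name and behavior are invented for illustration, and the registry side of Create is not shown because it is outside this diff:

#include <mxnet/operator.h>
#include <string>
#include <utility>
#include <vector>

// Invented example, not part of the PR: a property for an identity operator.
class IdentityProp : public mxnet::OperatorProperty {
 public:
  bool InferShape(std::vector<mxnet::TShape> *in_shape,
                  std::vector<mxnet::TShape> *out_shape) const override {
    if ((*in_shape)[0].ndim() == 0) return false;  // input still unknown
    out_shape->clear();
    out_shape->push_back((*in_shape)[0]);          // output mirrors input
    return true;
  }
  mxnet::OperatorProperty* Copy() const override {
    return new IdentityProp(*this);
  }
  mxnet::Operator* CreateOperator(mxnet::Context ctx) const override;  // per-device definition elsewhere
  std::string TypeString() const override { return "Identity"; }
  // Identity needs only the output gradient in Backward.
  std::vector<int> DeclareBackwardDependency(
      const std::vector<int> &out_grad,
      const std::vector<int> &in_data,
      const std::vector<int> &out_data) const override {
    return {out_grad[0]};
  }
  // The output can share memory with the input.
  std::vector<std::pair<int, int> > ForwardInplaceOption(
      const std::vector<int> &in_data,
      const std::vector<int> &out_data) const override {
    return {{out_data[0], in_data[0]}};
  }
};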