
Commit b5f485f

Merge pull request #14 from antinucleon/master

new interface

antinucleon committed Aug 14, 2015
2 parents 0cf889c + b1127a7
Showing 9 changed files with 327 additions and 247 deletions.
6 changes: 3 additions & 3 deletions Makefile
@@ -56,16 +56,16 @@ endif

 #BIN = test/test_threaded_engine test/api_registry_test
 BIN = test/api_registry_test
-OBJ = storage.o narray_op_cpu.o static_operator.o static_operator_cpu.o
+OBJ = storage.o narray_op_cpu.o
 # add threaded engine after it is done
-OBJCXX11 = engine.o narray.o c_api.o registry.o symbol.o operator.o fully_connect_op_cpu.o
+OBJCXX11 = engine.o narray.o c_api.o registry.o symbol.o fully_connect_op_cpu.o
 CUOBJ =
 SLIB = lib/libmxnet.so
 ALIB = lib/libmxnet.a
 LIB_DEP = $(DMLC_CORE)/libdmlc.a

 ifeq ($(USE_CUDA), 1)
-CUOBJ += narray_op_gpu.o static_operator_gpu.o fully_connect_op_gpu.o
+CUOBJ += narray_op_gpu.o fully_connect_op_gpu.o
 endif

 .PHONY: clean all test lint doc
33 changes: 0 additions & 33 deletions include/mxnet/base.h
@@ -41,39 +41,6 @@ typedef mshadow::index_t index_t;
 /*! \brief data type that will be used to store ndarray */
 typedef mshadow::default_real_t real_t;

-/*! \brief option to pass into the forward function */
-struct Option {
-  /*! \brief whether it is training phase*/
-  int is_train;
-};
-/*! \brief gradient request type the request can have */
-enum GradReqType {
-  /*! \brief no operation, do not write gradient */
-  kNullOp = 0,
-  /*! \brief write gradient to provided space */
-  kWriteTo = 1,
-  /*! \brief same as kWriteTo, but provided space is same as space of input-data */
-  kWriteInplace = 2,
-  /*! \brief add to the provided space */
-  kAddTo = 3
-};
-/*! \brief input argument type of the operator have */
-enum ArgType {
-  /*! \brief data argument */
-  kDataArg = 0,
-  /*! \brief weight argument */
-  kWeightArg = 1,
-  /*! \brief bias argument */
-  kBiasArg = 2
-};
-/*! \brief Property for engine schedule */
-enum Property {
-  /*! \brief Op contains interanl state, won't influence engine schedule */
-  kContainInteralState = 1,
-  /*! \brief Op forward require random number, will influence engine schedule */
-  kForwardRequireRnd = 2,
-};
-
 /*! \brief context information about the execution enviroment */
 struct Context {
   /*! \brief the device type we run the op can be cpu::kDevMask or gpu::kDevMask */
150 changes: 73 additions & 77 deletions include/mxnet/operator.h
@@ -10,11 +10,34 @@
 #include <dmlc/base.h>
 #include <vector>
 #include "./base.h"
-#if DMLC_USE_CXX11 == 1
+#if DMLC_USE_CXX11
 #include "./narray.h"
 #include "./dag_engine.h"
 #endif
+#include "./symbolic.h"

 namespace mxnet {
+/*! \brief option to pass into the forward function */
+struct Option {
+  /*! \brief whether it is training phase*/
+  int is_train;
+};
+
+/*! \brief operation request type to Forward and Backward */
+enum OpReqType {
+  /*! \brief no operation, do not write anything */
+  kNullOp,
+  /*! \brief write gradient to provided space */
+  kWriteTo,
+  /*!
+   * \brief perform an inplace write,
+   *  Target shares memory with one of input arguments.
+   *  This option only happen when
+   */
+  kWriteInplace,
+  /*! \brief add to the provided space */
+  kAddTo
+};
 /*!
  * \brief StaticOperator interface
  * StaticOperator is a stateful object that can be used to call forward and backprop
@@ -29,112 +52,85 @@ class StaticOperator {
   /*! \brief destructor */
   virtual ~StaticOperator() {}
   /*!
-   * \brief describe property of op
-   * \return a bit map in int
-   */
-  virtual int DescribeProperty() const {
-    // default most of layer only conatin internal state
-    return kContainInteralState;
-  }
-  /*!
-   * \brief perform a forward operation of StaticOperator, save the output to TBlob
-   * \param opt option on Forward such as whether this is training phase
+   * \brief perform a forward operation of StaticOperator, save the output to TBlob.
+   * \param opt option on Forward such as whether this is training phase.
    * \param ctx runtime context
    * \param in_data array of input data, it is const
-   * \param out_data array of output data,
+   * \param req the request types of saving operation, can only be kWriteTo or kWriteInplace.
+   * \param out_data array of output data, pointer is used to indicate that this is holder
    *        the space of TBlob in out_data must be pre-allocated with InferShape
+   * \sa OpReqType
    */
   virtual void Forward(Option opt,
                        RunContext ctx,
                        const std::vector<TBlob> &in_data,
+                       const std::vector<OpReqType> &req,
                        const std::vector<TBlob> &out_data) = 0;
   /*!
-   * \brief perform a backward operation of the StaticOperator to get the gradient
+   * \brief Perform a backward Operation, write gradient to the in_grad.
    * \param ctx runtime context
-   * \param grad_next the gradient value we get from output of the StaticOperator
-   * \param in_data the array of input data
-   * \param out_data the array of output data
-   * \param out_grad array of output gradient, there could be three possible TBlob
-   *        in the each element in the array
-   * \param req request types of the gradient saving operation
-   *        only inplace will change input data
-   * \sa GradReqType
+   * \param out_grad the gradient value we get from output of the StaticOperator
+   * \param in_data the array of input data.
+   * \param out_data the array of output data.
+   * \param req request types of the saving operation, can be all types.
+   * \param in_grad the array of gradient we need to write to.
+   * \sa OpReqType
    */
   virtual void Backward(RunContext ctx,
-                        const std::vector<TBlob> &grad_next,
-                        const std::vector<TBlob> &in_data,
-                        const std::vector<TBlob> &out_data,
-                        const std::vector<TBlob> &out_grad,
-                        const std::vector<GradReqType> &req) = 0;
-  /*!
-   * \brief factory function, create a new StaticOperator
-   * \param type the type of StaticOperator
-   * \param ctx the context device type of StaticOperator
-   * \return a pointer of StaticOperator object
-   */
-  static StaticOperator *Create(const char *type, Context ctx);
+                        const std::vector<TBlob> &out_grad,
+                        const std::vector<TBlob> &in_data,
+                        const std::vector<TBlob> &out_data,
+                        const std::vector<OpReqType> &req,
+                        const std::vector<TBlob> &in_grad) = 0;
 };
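The new `req` argument replaces the old per-layer property machinery: the caller now tells the operator how each output or gradient blob should be written. As a rough illustration of that contract (not code from this commit; `SaveOutput` is a hypothetical helper), an implementation might dispatch on the request like this:

```cpp
// Hypothetical helper illustrating the OpReqType contract; not part of this diff.
#include <mxnet/operator.h>

namespace {
// Write one result value into dst according to the caller's request type.
void SaveOutput(mxnet::real_t *dst, mxnet::real_t value, mxnet::OpReqType req) {
  switch (req) {
    case mxnet::kNullOp:        // caller does not want this output written
      break;
    case mxnet::kWriteTo:       // plain overwrite into pre-allocated space
    case mxnet::kWriteInplace:  // same, but dst aliases one of the inputs
      *dst = value;
      break;
    case mxnet::kAddTo:         // accumulate, e.g. when summing gradients
      *dst += value;
      break;
  }
}
}  // namespace
```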

-#if DMLC_USE_CXX11 == 1
+#if DMLC_USE_CXX11
 /*!
- * \brief operator interface
- * operator is an object can be scheduled by DAG engine directly.
+ * \brief Operator interface.
+ *  Operator is an object can have Forward and Backward function.
  *
- * This interface relies on NArray. The user should prepare the input NArray and
- * output NArray by themselves.
- * \sa Operator
+ *  It can be created from
  */
 class Operator {
  public:
   /*! \brief destructor */
   virtual ~Operator() {}
   /*!
-   * \brief describe property of op
-   * \return a bit map in int
+   * \brief Perform a Forward operation of Operator
+   *  After this operation, user can get the result by using function head.
    */
-  virtual int DescribeProperty() const = 0;
+  virtual void Forward() = 0;
   /*!
-   * \brief perform a forward operation of operator, save the output to NArray
-   *        This method only pushes an execution request to the DAG engine, and
-   *        return immediately. Actual execution is conducted by the DAG engine.
-   * \param opt option on Forward such as whether this is training phase
-   * \param ctx runtime context
-   * \param in_data array of input data, it is const
-   * \param out_data array of output data,
-   *        the space of NArray in out_data must be pre-allocated with InferShape
-   * \sa NArray
-   */
-  virtual void Forward(Option opt,
-                       RunContext ctx,
-                       const std::vector<NArray> &in_data,
-                       const std::vector<NArray> &out_data) = 0;
-  /*!
-   * \brief perform a backward operation of the operator to get the gradient
-   *        This method only pushes an execution request to the DAG engine, and
-   *        return immediately. Actual execution is conducted by the DAG engine.
-   * \param ctx runtime context
-   * \param grad_next the gradient value of the output of the operator, used by chain rule.
-   * \param in_data the array of input data
-   * \param out_data the array of output data
-   * \param out_grad array of output gradient
-   * \param req request types of the gradient saving operation
-   *        only inplace will change input data
-   * \sa GradReqType, NArray
-   */
-  virtual void Backward(RunContext ctx,
-                        const std::vector<NArray> &grad_next,
-                        const std::vector<NArray> &in_data,
-                        const std::vector<NArray> &out_data,
-                        const std::vector<NArray> &out_grad,
-                        const std::vector<GradReqType> &req) = 0;
+   * \brief Perform a Backward operation of the Operator.
+   *  This must be called after Forward.
+   *  After this operation, NArrays specified by grad_in_args_store will be updated accordingly.
+   */
+  virtual void Backward() = 0;
+  /*! \return get array of heads in the operator */
+  virtual const std::vector<NArray> &head() const = 0;
   /*!
-   * \brief Create a wrapper of static operator to wrap it into Operator.
-   *  This function takes ownership of op
-   * \param op static operator to wrap from
-   * \param ctx context of the created operator
-   * \return a wrapper operator
+   * \brief Create an operator by bind symbol with context and arguments.
+   *  If user do not want to compute the gradients of i-th argument, grad_req_type[i] can be kNullOp.
+   *
+   * \param ctx the context of binding.
+   * \param symbol the symbol that specifies the output of Forward pass.
+   * \param in_args the NArray that stores the input arguments to the symbol.
+   * \param grad_in_args_store NArray that is used to store the gradient output of the input arguments.
+   * \param grad_req_type requirment type of gradient saving. Can only be in {kNullOp, kAddTo, kWriteTo}.
    */
-  static Operator *CreateWrapper(StaticOperator *op, Context ctx);
+  static Operator *Bind(Symbol symbol,
+                        Context ctx,
+                        const std::vector<NArray> &in_args,
+                        const std::vector<NArray> &grad_in_args_store,
+                        const std::vector<OpReqType> &grad_req_type);
 };  // class operator
 #endif
 }  // namespace mxnet
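Taken together, the new interface turns an Operator into a bound executor: `Bind` attaches a `Symbol` to concrete `NArray` arguments once, after which `Forward` and `Backward` take no per-call arguments. Below is a minimal usage sketch, assuming a `Symbol` already built with the symbolic API from `./symbolic.h` (not shown in this diff); `RunOnce` is a hypothetical wrapper, not part of the commit:

```cpp
// Hypothetical usage of the new Operator interface; not from this commit's tests.
#include <mxnet/operator.h>
#include <vector>

void RunOnce(mxnet::Symbol net, mxnet::Context ctx,
             const std::vector<mxnet::NArray> &args,
             const std::vector<mxnet::NArray> &grad_store) {
  // Ask for every input gradient to be overwritten; kNullOp would skip one.
  std::vector<mxnet::OpReqType> req(args.size(), mxnet::kWriteTo);
  mxnet::Operator *op = mxnet::Operator::Bind(net, ctx, args, grad_store, req);
  op->Forward();                                            // run the forward pass
  const std::vector<mxnet::NArray> &outputs = op->head();   // read the results
  (void)outputs;   // e.g. copy to CPU, compute a loss, ...
  op->Backward();  // gradients are written into grad_store, as requested
  delete op;
}
```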