This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

new interface #14

Merged — 3 commits, merged Aug 14, 2015

Makefile — 6 changes: 3 additions & 3 deletions

@@ -56,16 +56,16 @@ endif

 #BIN = test/test_threaded_engine test/api_registry_test
 BIN = test/api_registry_test
-OBJ = storage.o narray_op_cpu.o static_operator.o static_operator_cpu.o
+OBJ = storage.o narray_op_cpu.o
 # add threaded engine after it is done
-OBJCXX11 = engine.o narray.o c_api.o registry.o symbol.o operator.o fully_connect_op_cpu.o
+OBJCXX11 = engine.o narray.o c_api.o registry.o symbol.o fully_connect_op_cpu.o
 CUOBJ =
 SLIB = lib/libmxnet.so
 ALIB = lib/libmxnet.a
 LIB_DEP = $(DMLC_CORE)/libdmlc.a

 ifeq ($(USE_CUDA), 1)
-CUOBJ += narray_op_gpu.o static_operator_gpu.o fully_connect_op_gpu.o
+CUOBJ += narray_op_gpu.o fully_connect_op_gpu.o
 endif

 .PHONY: clean all test lint doc

include/mxnet/base.h — 33 changes: 0 additions & 33 deletions

@@ -41,39 +41,6 @@ typedef mshadow::index_t index_t;
 /*! \brief data type that will be used to store ndarray */
 typedef mshadow::default_real_t real_t;

-/*! \brief option to pass into the forward function */
-struct Option {
-  /*! \brief whether it is training phase*/
-  int is_train;
-};
-/*! \brief gradient request type the request can have */
-enum GradReqType {
-  /*! \brief no operation, do not write gradient */
-  kNullOp = 0,
-  /*! \brief write gradient to provided space */
-  kWriteTo = 1,
-  /*! \brief same as kWriteTo, but provided space is same as space of input-data */
-  kWriteInplace = 2,
-  /*! \brief add to the provided space */
-  kAddTo = 3
-};
-/*! \brief input argument type of the operator have */
-enum ArgType {
-  /*! \brief data argument */
-  kDataArg = 0,
-  /*! \brief weight argument */
-  kWeightArg = 1,
-  /*! \brief bias argument */
-  kBiasArg = 2
-};
-/*! \brief Property for engine schedule */
-enum Property {
-  /*! \brief Op contains interanl state, won't influence engine schedule */
-  kContainInteralState = 1,
-  /*! \brief Op forward require random number, will influence engine schedule */
-  kForwardRequireRnd = 2,
-};

 /*! \brief context information about the execution enviroment */
 struct Context {
   /*! \brief the device type we run the op can be cpu::kDevMask or gpu::kDevMask */

include/mxnet/operator.h — 150 changes: 73 additions & 77 deletions

@@ -10,11 +10,34 @@
 #include <dmlc/base.h>
 #include <vector>
 #include "./base.h"
-#if DMLC_USE_CXX11 == 1
+#if DMLC_USE_CXX11
 #include "./narray.h"
 #include "./dag_engine.h"
 #endif
+#include "./symbolic.h"

 namespace mxnet {
+/*! \brief option to pass into the forward function */
+struct Option {
+  /*! \brief whether it is training phase */
+  int is_train;
+};
+
+/*! \brief operation request type passed to Forward and Backward */
+enum OpReqType {
+  /*! \brief no operation, do not write anything */
+  kNullOp,
+  /*! \brief write gradient to provided space */
+  kWriteTo,
+  /*!
+   * \brief perform an inplace write:
+   *  the target shares memory with one of the input arguments.
+   */
+  kWriteInplace,
+  /*! \brief add to the provided space */
+  kAddTo
+};
 /*!
  * \brief StaticOperator interface
  * StaticOperator is a stateful object that can be used to call forward and backprop
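
A note for readers of this diff: `OpReqType` is what lets the caller choose between overwriting and accumulating a result. Below is a minimal, self-contained sketch of how a kernel might dispatch on it; the `Assign` helper and the driver in `main` are invented for illustration and are not part of this change.

```cpp
#include <cassert>
#include <vector>

// Mirror of the OpReqType added in this diff, repeated here so the
// sketch compiles on its own.
enum OpReqType { kNullOp, kWriteTo, kWriteInplace, kAddTo };

// Illustrative helper (not part of this diff): apply one computed value
// to a destination according to the request type. kWriteTo and
// kWriteInplace both overwrite; kAddTo accumulates; kNullOp skips the
// write entirely.
inline void Assign(float *dst, OpReqType req, float src) {
  switch (req) {
    case kNullOp:                     break;  // leave dst untouched
    case kWriteTo:
    case kWriteInplace: *dst = src;   break;  // overwrite
    case kAddTo:        *dst += src;  break;  // accumulate
  }
}

int main() {
  std::vector<float> grad = {1.0f, 1.0f};
  Assign(&grad[0], kAddTo, 2.0f);    // accumulates: 1 + 2
  Assign(&grad[1], kWriteTo, 2.0f);  // overwrites: 2
  assert(grad[0] == 3.0f && grad[1] == 2.0f);
  return 0;
}
```

`kAddTo` is the interesting case: when several consumers contribute gradient to the same input, each backward call can accumulate into one buffer instead of allocating temporaries.
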
@@ -29,112 +52,85 @@ class StaticOperator {
   /*! \brief destructor */
   virtual ~StaticOperator() {}
-  /*!
-   * \brief describe property of op
-   * \return a bit map in int
-   */
-  virtual int DescribeProperty() const {
-    // default most of layer only conatin internal state
-    return kContainInteralState;
-  }
   /*!
-   * \brief perform a forward operation of StaticOperator, save the output to TBlob
-   * \param opt option on Forward such as whether this is training phase
+   * \brief perform a forward operation of StaticOperator, save the output to TBlob.
+   * \param opt option on Forward such as whether this is training phase.
    * \param ctx runtime context
    * \param in_data array of input data, it is const
-   * \param out_data array of output data,
+   * \param req the request types of the saving operation, can only be kWriteTo or kWriteInplace.
+   * \param out_data array of output data; the pointer is used to indicate that this is a holder.
    *        the space of TBlob in out_data must be pre-allocated with InferShape
+   * \sa OpReqType
    */
   virtual void Forward(Option opt,
                        RunContext ctx,
                        const std::vector<TBlob> &in_data,
+                       const std::vector<OpReqType> &req,
                        const std::vector<TBlob> &out_data) = 0;
   /*!
-   * \brief perform a backward operation of the StaticOperator to get the gradient
+   * \brief Perform a backward operation; write the gradients to in_grad.
    * \param ctx runtime context
-   * \param grad_next the gradient value we get from output of the StaticOperator
-   * \param in_data the array of input data
-   * \param out_data the array of output data
-   * \param out_grad array of output gradient, there could be three possible TBlob
-   *        in the each element in the array
-   * \param req request types of the gradient saving operation
-   *        only inplace will change input data
-   * \sa GradReqType
+   * \param out_grad the gradient values we get from the output of the StaticOperator
+   * \param in_data the array of input data.
+   * \param out_data the array of output data.
+   * \param req request types of the saving operation; all types are allowed here.
+   * \param in_grad the array of gradients to write to.
+   * \sa OpReqType
    */
   virtual void Backward(RunContext ctx,
-                        const std::vector<TBlob> &grad_next,
-                        const std::vector<TBlob> &in_data,
-                        const std::vector<TBlob> &out_data,
-                        const std::vector<TBlob> &out_grad,
-                        const std::vector<GradReqType> &req) = 0;
-  /*!
-   * \brief factory function, create a new StaticOperator
-   * \param type the type of StaticOperator
-   * \param ctx the context device type of StaticOperator
-   * \return a pointer of StaticOperator object
-   */
-  static StaticOperator *Create(const char *type, Context ctx);
+                        const std::vector<TBlob> &out_grad,
+                        const std::vector<TBlob> &in_data,
+                        const std::vector<TBlob> &out_data,
+                        const std::vector<OpReqType> &req,
+                        const std::vector<TBlob> &in_grad) = 0;
 };
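
To see how the revised `StaticOperator` signatures fit together, here is a toy scale operator (out = in * scale) that mirrors the new `Forward`/`Backward` shapes. `TBlob`, `RunContext`, and `Option` are replaced by simplified stand-ins so the sketch compiles on its own; none of this is the real mxnet implementation.

```cpp
#include <cstddef>
#include <vector>

// Simplified stand-ins so this sketch is self-contained; the real
// TBlob, RunContext, and Option live in the mxnet headers and carry
// far more information (shape, device, dtype, ...).
struct TBlob { float *dptr; std::size_t size; };
struct RunContext {};
struct Option { int is_train; };
enum OpReqType { kNullOp, kWriteTo, kWriteInplace, kAddTo };

// Hypothetical operator mirroring the new StaticOperator signatures:
// out = in * scale in Forward, in_grad = out_grad * scale in Backward.
class ScaleOp {
 public:
  explicit ScaleOp(float scale) : scale_(scale) {}

  void Forward(Option /*opt*/, RunContext /*ctx*/,
               const std::vector<TBlob> &in_data,
               const std::vector<OpReqType> &req,
               const std::vector<TBlob> &out_data) {
    // Per the new docs, req[0] here may only be kWriteTo or
    // kWriteInplace, so a plain overwrite is always correct.
    for (std::size_t i = 0; i < in_data[0].size; ++i)
      out_data[0].dptr[i] = in_data[0].dptr[i] * scale_;
  }

  void Backward(RunContext /*ctx*/,
                const std::vector<TBlob> &out_grad,
                const std::vector<TBlob> &in_data,
                const std::vector<TBlob> &out_data,
                const std::vector<OpReqType> &req,
                const std::vector<TBlob> &in_grad) {
    (void)in_data; (void)out_data;   // unused by this simple op
    if (req[0] == kNullOp) return;   // caller asked for no gradient
    for (std::size_t i = 0; i < out_grad[0].size; ++i) {
      float g = out_grad[0].dptr[i] * scale_;
      if (req[0] == kAddTo) in_grad[0].dptr[i] += g;  // accumulate
      else                  in_grad[0].dptr[i]  = g;  // overwrite
    }
  }

 private:
  float scale_;
};

int main() {
  float in[2] = {1.0f, 2.0f}, out[2] = {0.0f, 0.0f};
  float og[2] = {1.0f, 1.0f}, ig[2] = {0.0f, 0.0f};
  ScaleOp op(3.0f);
  op.Forward(Option{1}, RunContext{}, {{in, 2}}, {kWriteTo}, {{out, 2}});
  op.Backward(RunContext{}, {{og, 2}}, {{in, 2}}, {{out, 2}},
              {kWriteTo}, {{ig, 2}});
  // Now out == {3, 6} and ig == {3, 3}.
  return 0;
}
```

Note how `Backward` honors `req[0]`: `kNullOp` skips the write entirely, `kAddTo` accumulates, and the write modes overwrite.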

-#if DMLC_USE_CXX11 == 1
+#if DMLC_USE_CXX11
 /*!
- * \brief operator interface
- * operator is an object can be scheduled by DAG engine directly.
+ * \brief Operator interface.
+ * Operator is an object that has Forward and Backward functions.
  *
- * This interface relies on NArray. The user should prepare the input NArray and
- * output NArray by themselves.
- * \sa Operator
+ * It can be created from a Symbol with the static Bind function below.
  */
 class Operator {
  public:
   /*! \brief destructor */
   virtual ~Operator() {}
   /*!
-   * \brief describe property of op
-   * \return a bit map in int
+   * \brief Perform a Forward operation of the Operator.
+   * After this operation, the user can get the result with the head function.
    */
-  virtual int DescribeProperty() const = 0;
+  virtual void Forward() = 0;
   /*!
-   * \brief perform a forward operation of operator, save the output to NArray
-   * This method only pushes an execution request to the DAG engine, and
-   * return immediately. Actual execution is conducted by the DAG engine.
-   * \param opt option on Forward such as whether this is training phase
-   * \param ctx runtime context
-   * \param in_data array of input data, it is const
-   * \param out_data array of output data,
-   *        the space of NArray in out_data must be pre-allocated with InferShape
-   * \sa NArray
+   * \brief Perform a Backward operation of the Operator.
+   * This must be called after Forward.
+   * After this operation, the NArrays specified by grad_in_args_store will be updated accordingly.
    */
-  virtual void Forward(Option opt,
-                       RunContext ctx,
-                       const std::vector<NArray> &in_data,
-                       const std::vector<NArray> &out_data) = 0;
-  /*!
-   * \brief perform a backward operation of the operator to get the gradient
-   * This method only pushes an execution request to the DAG engine, and
-   * return immediately. Actual execution is conducted by the DAG engine.
-   * \param ctx runtime context
-   * \param grad_next the gradient value of the output of the operator, used by chain rule.
-   * \param in_data the array of input data
-   * \param out_data the array of output data
-   * \param out_grad array of output gradient
-   * \param req request types of the gradient saving operation
-   *        only inplace will change input data
-   * \sa GradReqType, NArray
-   */
-  virtual void Backward(RunContext ctx,
-                        const std::vector<NArray> &grad_next,
-                        const std::vector<NArray> &in_data,
-                        const std::vector<NArray> &out_data,
-                        const std::vector<NArray> &out_grad,
-                        const std::vector<GradReqType> &req) = 0;
+  virtual void Backward() = 0;
+  /*! \return the array of heads in the operator */
+  virtual const std::vector<NArray> &head() const = 0;
   /*!
-   * \brief Create a wrapper of static operator to wrap it into Operator.
-   * This function takes ownership of op
-   * \param op static operator to wrap from
-   * \param ctx context of the created operator
-   * \return a wrapper operator
+   * \brief Create an Operator by binding a Symbol to a context and arguments.
+   * If the user does not want to compute the gradient of the i-th argument, grad_req_type[i] can be kNullOp.
+   *
+   * \param ctx the context of the binding.
+   * \param symbol the symbol that specifies the output of the Forward pass.
+   * \param in_args the NArrays that store the input arguments of the symbol.
+   * \param grad_in_args_store the NArrays used to store the gradients of the input arguments.
+   * \param grad_req_type requirement type of gradient saving; can only be in {kNullOp, kAddTo, kWriteTo}.
    */
-  static Operator *CreateWrapper(StaticOperator *op, Context ctx);
+  static Operator *Bind(Symbol symbol,
+                        Context ctx,
+                        const std::vector<NArray> &in_args,
+                        const std::vector<NArray> &grad_in_args_store,
+                        const std::vector<OpReqType> &grad_req_type);
 };  // class operator
 #endif
 }  // namespace mxnet
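
Finally, the intended call sequence for the new symbol-based interface, as far as the docs above describe it: bind once, then `Forward`, read results through `head()`, then `Backward`, with gradients landing in `grad_in_args_store`. Everything below (the toy `y = 2x` operator, the shared-buffer `NArray` stand-in) is invented so the sketch runs on its own; only the method names and the call order come from this diff.

```cpp
#include <cassert>
#include <memory>
#include <vector>

// Stand-in types so this sketch compiles by itself; they are not the
// real mxnet classes. NArray is modeled as a shared handle to a buffer,
// so writing through a copy is visible to the original owner.
struct NArray {
  std::shared_ptr<std::vector<float>> data =
      std::make_shared<std::vector<float>>(1, 0.0f);
};
struct Symbol {};
struct Context { int dev_mask = 0; };
enum OpReqType { kNullOp, kWriteTo, kWriteInplace, kAddTo };

// Surface of the interface in this diff (signatures only).
class Operator {
 public:
  virtual ~Operator() {}
  virtual void Forward() = 0;
  virtual void Backward() = 0;
  virtual const std::vector<NArray> &head() const = 0;
  static Operator *Bind(Symbol symbol, Context ctx,
                        const std::vector<NArray> &in_args,
                        const std::vector<NArray> &grad_in_args_store,
                        const std::vector<OpReqType> &grad_req_type);
};

// Toy implementation: the "symbol" is hardcoded to y = 2 * x.
class ToyOperator : public Operator {
 public:
  ToyOperator(std::vector<NArray> in, std::vector<NArray> gstore,
              std::vector<OpReqType> req)
      : in_(in), gstore_(gstore), req_(req), heads_(1) {}
  void Forward() override {
    (*heads_[0].data)[0] = 2.0f * (*in_[0].data)[0];
  }
  void Backward() override {
    if (req_[0] == kNullOp) return;  // gradient not requested
    float g = 2.0f;                  // d(2x)/dx
    if (req_[0] == kAddTo) (*gstore_[0].data)[0] += g;
    else                   (*gstore_[0].data)[0]  = g;
  }
  const std::vector<NArray> &head() const override { return heads_; }
 private:
  std::vector<NArray> in_, gstore_;
  std::vector<OpReqType> req_;
  std::vector<NArray> heads_;
};

Operator *Operator::Bind(Symbol, Context,
                         const std::vector<NArray> &in_args,
                         const std::vector<NArray> &grad_in_args_store,
                         const std::vector<OpReqType> &grad_req_type) {
  return new ToyOperator(in_args, grad_in_args_store, grad_req_type);
}

int main() {
  std::vector<NArray> in(1), gstore(1);
  (*in[0].data)[0] = 3.0f;
  Operator *op = Operator::Bind(Symbol{}, Context{}, in, gstore, {kWriteTo});
  op->Forward();
  assert((*op->head()[0].data)[0] == 6.0f);  // read result via head()
  op->Backward();
  assert((*gstore[0].data)[0] == 2.0f);      // gradient landed in the store
  delete op;
  return 0;
}
```

Modeling the `NArray` stand-in as a shared handle matters here: `Bind` receives the gradient store by const reference, yet the caller still observes the writes because copies alias the same buffer, which mirrors the shared-handle semantics the `grad_in_args_store` contract appears to rely on.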