Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
a bit more minor change
Browse files Browse the repository at this point in the history
  • Loading branch information
antinucleon committed Aug 15, 2015
1 parent 3626c34 commit 80a3d42
Show file tree
Hide file tree
Showing 9 changed files with 155 additions and 57 deletions.
37 changes: 0 additions & 37 deletions include/mxnet/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,43 +41,6 @@ typedef mshadow::index_t index_t;
/*! \brief data type that will be used to store ndarray */
typedef mshadow::default_real_t real_t;

/*! \brief context information about the execution environment (device placement) */
struct Context {
  /*! \brief the device type we run the op on, can be cpu::kDevMask or gpu::kDevMask */
  int dev_mask;
  /*! \brief device id we are going to run it on */
  int dev_id;
  /*! \brief default constructor: CPU device 0 */
  Context() : dev_mask(cpu::kDevMask), dev_id(0) {}
  /*!
   * \brief constructor of context
   * \param dev_mask the device mask, cpu::kDevMask or gpu::kDevMask
   * \param dev_id the device id
   */
  Context(int dev_mask, int dev_id)
      : dev_mask(dev_mask), dev_id(dev_id) {}
  /*!
   * \brief check if current context equals another one
   * \param b another context to compare
   * \return whether dev mask and id are both the same
   */
  inline bool operator==(const Context &b) const {
    return dev_mask == b.dev_mask && dev_id == b.dev_id;
  }
};


/*!
 * \brief execution context provides the information needed
 *  in runtime to actually execute the operation
 */
struct RunContext {
  /*!
   * \brief the stream of the device, can be NULL or Stream<gpu>* in GPU mode
   */
  // NOTE(review): stream has no default initializer — the creator of a
  // RunContext must set it before handing the context to an executor.
  void *stream;
};

/*! \brief dynamic shape type */
typedef mshadow::TShape TShape;
/*! \brief storage container type */
Expand Down
80 changes: 80 additions & 0 deletions include/mxnet/context.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
/*!
* Copyright (c) 2015 by Contributors
* \file context.h
* \brief Context information and resources in mxnet.
*/
#ifndef MXNET_CONTEXT_H_
#define MXNET_CONTEXT_H_

namespace mxnet {

/*! \brief Context information about the execution environment (device placement) */
struct Context {
  /*! \brief the device type we run the op on, can be cpu::kDevMask or gpu::kDevMask */
  int dev_mask;
  /*! \brief device id we are going to run it on */
  int dev_id;
  /*! \brief default constructor: CPU device 0 */
  Context() : dev_mask(cpu::kDevMask), dev_id(0) {}
  /*!
   * \brief constructor of context
   * \param dev_mask the device mask, cpu::kDevMask or gpu::kDevMask
   * \param dev_id the device id
   */
  Context(int dev_mask, int dev_id)
      : dev_mask(dev_mask), dev_id(dev_id) {}
  /*!
   * \brief check if current context equals another one
   * \param b another context to compare
   * \return whether dev mask and id are both the same
   */
  inline bool operator==(const Context &b) const {
    return dev_mask == b.dev_mask && dev_id == b.dev_id;
  }
  /*!
   * \brief check if current context differs from another one
   * \param b another context to compare
   * \return whether dev mask or id differ
   */
  inline bool operator!=(const Context &b) const {
    return !(*this == b);
  }
};

/*!
 * \brief execution time context.
 *  The information needed in runtime for actual execution.
 */
struct RunContext {
  /*!
   * \brief the stream of the device, can be NULL or Stream<gpu>* in GPU mode
   */
  // NOTE(review): stream has no default initializer — the creator of a
  // RunContext must set it before handing the context to an executor.
  void *stream;
};

/*!
 * \brief Additional resources handed to an Operator at runtime.
 */
struct Resource {
  /*! \brief Resource type, indicating what the pointer type is */
  enum Type {
    /*! \brief mshadow::Random<xpu> object */
    kRandom,
    /*! \brief Temporal space */
    kTempSpace
  };
  /*! \brief pointer to the resource; the pointee type is determined by Type */
  // NOTE(review): ptr is not initialized here — the producer of the Resource
  // is responsible for setting it before the struct is used.
  void *ptr;
};

/*!
* \brief The resources that can be requested by Operator
*/
struct ResourceRequest {
/*! \brief type of resources */
Resource::Type type;
/*! \brief size requirment if it is an temp space request */
size_t space_size;
/*! \brief default constructor */
ResourceRequest() {}
/*!
* \brief default constructor, allow implicit conversion
* \param type type of resources
*/
ResourceRequest(Resource::Type type) : type(type) {} // NOLINT(*)
};

} // namespace mxnet
#endif // MXNET_CONTEXT_H_
1 change: 1 addition & 0 deletions include/mxnet/dag_engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include <functional>
#include <vector>
#include "./base.h"
#include "./context.h"

namespace mxnet {
/*!
Expand Down
2 changes: 2 additions & 0 deletions include/mxnet/narray.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
#include <dmlc/logging.h>
#include <memory>
#include "./base.h"
#include "./context.h"
#include "./storage.h"
#include "./context.h"
#include "./dag_engine.h"
// check c++11
#if DMLC_USE_CXX11 == 0
Expand Down
83 changes: 67 additions & 16 deletions include/mxnet/operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,9 @@
#include <string>
#include <utility>
#include "./base.h"
#include "./context.h"

namespace mxnet {
/*! \brief option to pass into the forward function */
struct Option {
  /*! \brief whether it is training phase (nonzero) or not (0) */
  int is_train;
};

/*! \brief operation request type to Forward and Backward */
enum OpReqType {
/*! \brief no operation, do not write anything */
Expand All @@ -36,6 +31,28 @@ enum OpReqType {
kAddTo
};

/*!
 * \brief All the possible information needed by Operator.Forward and Backward
 *  This is the superset of RunContext.
 *  We use this data structure to bookkeep everything needed by Forward and Backward.
 * \sa Resource
 */
struct OpContext {
  /*! \brief whether it is training phase (nonzero) or not (0) */
  int is_train;
  /*! \brief Stream we are running on; can be NULL, or Stream<gpu>* in GPU mode */
  // NOTE(review): is_train and stream carry no default initializers — the
  // executor must fill them (e.g. via SetRunContext) before calling the op.
  void *stream;
  /*! \brief Resources requested by the operator, in the order they were declared */
  std::vector<Resource> requested;
  /*!
   * \brief set the RunContext related parts
   *  Only copies the stream; is_train and requested are managed separately.
   * \param ctx the context
   */
  inline void SetRunContext(const RunContext &ctx) {
    stream = ctx.stream;
  }
};

/*!
* \brief Operator interface.
 * Operator defines basic operation unit of optimized computation graph in mxnet.
Expand All @@ -54,30 +71,28 @@ class Operator {
/*! \brief virtual destructor so derived operators are destroyed correctly through a base pointer */
virtual ~Operator() {}
/*!
* \brief perform a forward operation of Operator, save the output to TBlob.
* \param opt option on Forward such as whether this is training phase.
* \param ctx runtime context
* \param ctx runtime context available to this call
* \param in_data array of input data, it is const
* \param req the request types of saving operation, can only be kWriteTo or kWriteInplace.
* \param out_data array of output data, pointer is used to indicate that this is holder
* the space of TBlob in out_data must be pre-allocated with InferShape
* \sa OpReqType
* \sa OpReqType, OpContext
*/
virtual void Forward(Option opt,
RunContext ctx,
virtual void Forward(const OpContext &ctx,
const std::vector<TBlob> &in_data,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &out_data) = 0;
/*!
* \brief Perform a backward Operation, write gradient to the in_grad.
* \param ctx runtime context
* \brief Perform a Backward Operation, write gradient to the in_grad.
* \param ctx runtime context available to this call
* \param out_grad the gradient value we get from output of the Operator
* \param in_data the array of input data.
* \param out_data the array of output data.
* \param req request types of the saving operation, can be all types.
* \param in_grad the array of gradient we need to write to.
* \sa OpReqType
* \sa OpReqType, OpContext
*/
virtual void Backward(RunContext ctx,
virtual void Backward(const OpContext &ctx,
const std::vector<TBlob> &out_grad,
const std::vector<TBlob> &in_data,
const std::vector<TBlob> &out_data,
Expand Down Expand Up @@ -114,10 +129,25 @@ class OperatorProperty {
/*!
 * \brief List the names of the return values of this Operator.
 *  The default implementation reports a single return value named "output".
 * \return vector of return value names
 */
virtual std::vector<std::string> ListReturns() const {
  std::vector<std::string> names;
  names.push_back("output");
  return names;
}
/*! \return total number of real return values of the Operator */
virtual int NumReturns() const {
  return 1;
}
/*!
 * \brief Number of return values visible when the Symbol is created.
 *  If NumVisibleReturns() == k and NumReturns() == n, only the first k
 *  returns appear in the resulting symbol; the remaining n - k act as
 *  auxiliary state for Backward. For example, Dropout returns
 *  [data, mask] with NumVisibleReturns() == 1, so sym = Dropout(input)
 *  exposes only data, while all returns show up in the out_data
 *  parameter of Backward when requested.
 * \return number of returns presented during symbol creation
 */
virtual int NumVisibleReturns() const {
  const int total = NumReturns();
  return total;
}
/*!
* \brief Set the parameters of the Operator.
* \param name parameter name
Expand Down Expand Up @@ -154,6 +184,27 @@ class OperatorProperty {
* subclasses override this function.
*/
virtual std::string TypeString() const = 0;
//--------------------------------------------------------
// All the functions below are optional to override.
//--------------------------------------------------------
/*!
 * \brief Declare additional resources required by the forward pass.
 *  The requested resources are presented in OpContext.requested,
 *  in the same order as returned here.
 *  The default implementation requests nothing.
 * \return additional resource requests
 */
virtual std::vector<ResourceRequest> ForwardResource() const {
  return {};
}
/*!
 * \brief Declare additional resources required by the backward pass.
 *  The requested resources are presented in OpContext.requested,
 *  in the same order as returned here.
 *  The default implementation requests nothing.
 * \return additional resource requests
 */
virtual std::vector<ResourceRequest> BackwardResource() const {
  std::vector<ResourceRequest> reqs;
  return reqs;
}
/*!
* \brief Declare the input requirement of Backward pass.
*
Expand Down
1 change: 1 addition & 0 deletions include/mxnet/storage.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#ifndef MXNET_STORAGE_H_
#define MXNET_STORAGE_H_
#include "./base.h"
#include "./context.h"

namespace mxnet {
/*! \brief memory allocator of storage */
Expand Down
1 change: 1 addition & 0 deletions src/narray/narray_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <dmlc/logging.h>
#include <mshadow/tensor.h>
#include <mxnet/base.h>
#include <mxnet/context.h>

namespace mxnet {
/*! \brief namespace to support all possible NArray operator */
Expand Down
5 changes: 2 additions & 3 deletions src/operator/fully_connected-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,7 @@ class FullyConnectedOp : public Operator {
this->param_ = p;
}

virtual void Forward(Option opt,
RunContext ctx,
virtual void Forward(const OpContext &ctx,
const std::vector<TBlob> &in_data,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &out_data) {
Expand All @@ -57,7 +56,7 @@ class FullyConnectedOp : public Operator {
}
}

virtual void Backward(RunContext ctx,
virtual void Backward(const OpContext &ctx,
const std::vector<TBlob> &out_grad,
const std::vector<TBlob> &in_data,
const std::vector<TBlob> &out_data,
Expand Down
2 changes: 1 addition & 1 deletion src/symbol/symbol.cc
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,7 @@ bool Symbol::InferShape(std::vector<TShape> *in_shape,
Symbol Symbol::Create(OperatorProperty *op) {
// use special representation for atomic symbol
auto node = std::make_shared<Node>(op, "");
size_t nret = op->NumReturns();
size_t nret = op->NumVisibleReturns();
Symbol s;
for (uint32_t i = 0; i < nret; ++i) {
s.heads_.push_back(DataEntry(node, i));
Expand Down

0 comments on commit 80a3d42

Please sign in to comment.