This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

symbol implementation and fix #7

Merged 8 commits on Jul 28, 2015
8 changes: 5 additions & 3 deletions Makefile
@@ -56,7 +56,7 @@ endif

#BIN = test/test_threaded_engine test/api_registry_test
BIN = test/api_registry_test
-OBJ = storage.o narray_op_cpu.o static_operator.o static_operator_cpu.o
+OBJ = storage.o narray_op_cpu.o static_operator.o static_operator_cpu.o atomic_symbol_cpu.o
# add threaded engine after it is done
OBJCXX11 = engine.o narray.o mxnet_api.o registry.o symbol.o operator.o
CUOBJ =
@@ -65,7 +65,7 @@ ALIB = api/libmxnet.a
LIB_DEP = $(DMLC_CORE)/libdmlc.a

ifeq ($(USE_CUDA), 1)
-CUOBJ += narray_op_gpu.o static_operator_gpu.o
+CUOBJ += narray_op_gpu.o static_operator_gpu.o atomic_symbol_gpu.o
endif

.PHONY: clean all test lint doc
@@ -87,7 +87,9 @@ static_operator_gpu.o: src/static_operator/static_operator_gpu.cu
symbol.o: src/symbol/symbol.cc
registry.o: src/registry.cc
mxnet_api.o: api/mxnet_api.cc
-operator.o: src/operator/operator.cc
+operator.o: src/operator/static_operator_wrapper.cc
+atomic_symbol_cpu.o: src/symbol/fully_connect_sym.cc
+atomic_symbol_gpu.o: src/symbol/fully_connect_sym.cu

api/libmxnet.a: $(OBJ) $(OBJCXX11) $(CUOBJ)
api/libmxnet.so: $(OBJ) $(OBJCXX11) $(CUOBJ)
7 changes: 4 additions & 3 deletions include/mxnet/atomic_symbol.h
@@ -14,7 +14,7 @@
#include "./tensor_blob.h"

namespace mxnet {
-class Operator;
+class StaticOperator;
/*!
* \brief AtomicSymbol is the base class of all atomic symbols.
* This is not meant to be used by user, it should be wrapped in Symbol, so that the same instance
@@ -54,7 +54,7 @@ class AtomicSymbol {
* InferShape will modify the vector to fill output TShape
* \return if the shape inference is successful, return true, else return false.
*/
-virtual bool InferShape(std::vector<TShape> *in_shape, std::vector<TShape> *out_shape) = 0;
+virtual bool InferShape(std::vector<TShape> *in_shape, std::vector<TShape> *out_shape) const = 0;
/*!
* \brief Copy this AtomicSymbol and returns a pointer to the copied object.
* this is a virtual function because different subclass of AtomicSymbol would copy differently.
@@ -66,7 +66,8 @@ class AtomicSymbol {
* Bind function of AtomicSymbol does not return NArrayOperator, but static operator.
* Calling bind from the Symbol wrapper would generate a NArrayOperator.
*/
-virtual Operator* Bind(Context ctx) const = 0;
+template<typename xpu>
+StaticOperator* Bind(Context ctx) const;
/*!
* \brief return the type string of the atomic symbol
* subclasses override this function.
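Note on this file's change: Bind is no longer a virtual function returning Operator*; it is a non-virtual member template over the device type xpu that returns a StaticOperator*. The per-device definitions can then live in separate translation units, which is what the new atomic_symbol_cpu.o and atomic_symbol_gpu.o targets in the Makefile build from fully_connect_sym.cc and fully_connect_sym.cu. A minimal sketch of that pattern follows; the names FullyConnectSymbol, FullyConnectOp, and the mshadow::cpu device tag are illustrative assumptions, not code taken from this pull request.

// fully_connect_sym.cc (sketch only): CPU-side definition of the Bind template.
// FullyConnectSymbol and FullyConnectOp are assumed names, for illustration.
#include <mxnet/atomic_symbol.h>
#include <mxnet/static_operator.h>

namespace mxnet {
class FullyConnectSymbol : public AtomicSymbol {
 public:
  // Shadows AtomicSymbol::Bind; resolved at compile time per device type.
  template<typename xpu>
  StaticOperator* Bind(Context ctx) const {
    // A real implementation would construct the device-specific operator,
    // e.g. new FullyConnectOp<xpu>(...); nullptr keeps the sketch minimal.
    return nullptr;
  }
  // InferShape, Copy, and the type string are elided in this sketch.
};

// Explicit instantiation pins the CPU version into atomic_symbol_cpu.o;
// fully_connect_sym.cu would instantiate Bind<mshadow::gpu> the same way.
template StaticOperator* FullyConnectSymbol::Bind<mshadow::cpu>(Context ctx) const;
}  // namespace mxnet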
2 changes: 1 addition & 1 deletion include/mxnet/narray.h
@@ -73,7 +73,7 @@ class NArray {
DAGEngine::Get()->WaitForVar(ptr_->var);
}
/*! \return the associated DAG variable of the narray.*/
-inline DAGEngine::Variable Var() const {
+inline DAGEngine::Variable var() const {
return ptr_->var;
}
/*!
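The renamed lower-case var() matches the style of the other inline accessors, and it is the handle callers pass to the DAG engine to declare read and write dependencies. A small sketch of that pattern, assuming the Push signature used in static_operator_wrapper.cc further down and an mxnet/dag_engine.h header (both assumptions):

// Sketch: enqueue a function that reads `src` and writes `dst`.
#include <mxnet/narray.h>
#include <mxnet/dag_engine.h>  // assumed header name for DAGEngine
#include <vector>

void PushNoop(const mxnet::NArray& src, const mxnet::NArray& dst,
              mxnet::Context exec_ctx) {
  std::vector<mxnet::DAGEngine::Variable> used = {src.var()};     // read deps
  std::vector<mxnet::DAGEngine::Variable> mutated = {dst.var()};  // write deps
  mxnet::DAGEngine::Get()->Push(
      [](mxnet::RunContext rctx) { /* device work would go here */ },
      exec_ctx, used, mutated);
}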
74 changes: 22 additions & 52 deletions include/mxnet/operator.h
@@ -17,93 +17,63 @@

namespace mxnet {
/*!
- * \brief static operator interface (current interface have not yet todo with scheduler),
- * operator is a stateful object that can be used to call forward and backprop
- *
- * This interface relies on pre-allocated memory in TBlob, the caller need to set
- * the memory region in TBlob correctly before calling Forward and Backward
+ * \brief operator interface
+ * operator is an object can be scheduled by DAG engine directly.
+ *
+ * This interface relies on NArray. The user should prepare the input NArray and
+ * output NArray by themselves.
+ * \sa Operator
*/
class Operator {
public:
-/*!
- * \brief construct Operator from StaticOperator and Context
- * \param op StaticOperator to wrap
- * \param ctx Context of the Operator
- */
-Operator(StaticOperator* op, Context ctx);
-/*!
- * \brief get types of input argument of this oeprator
- * \return a vector corresponding to type of each argument
- * this order is same as the order of inputs in Forward, InferShape and Backward
- */
-virtual std::vector<ArgType> DescribeArgs() const;
/*!
* \brief describe property of op
* \return a bit map in int
*/
virtual int DescribeProperty() const;
-/*!
- * \brief set param for the operator from string
- * \param name parameter name
- * \param val string for configuration
- */
-virtual void SetParam(const char *name, const char *val);
-/*!
- * \brief inter the shapes of outputs and unknown input arguments
- * \param in_shape the shape of input arguments of the operator
- * this should be of same length as the vector returned by DescribeArgs
- * in_shape allows unknown elements, which are checked by shape.ndim() == 0.
- * For unknown shapes, InferShape will try to fill in the correct Shape in in_shape
- * For known shapes, InferShape will check shape consistency
- *
- * common practice: set the shape of data input, and usually weight's shape can be infered
- *
- * \param out_shape the shape of outputs of the operator
- * InferShape will modify the vector to fill output TShape
- */
-virtual void InferShape(std::vector<TShape> *in_shape,
-                        std::vector<TShape> *out_shape);

/*!
- * \brief set the context of the Operator
+ * \brief set the global context of the Operator
* \param ctx the context to be set to
*/
virtual void SetContext(Context ctx);
/*!
- * \brief perform a forward operation of operator, save the output to TBlob
+ * \brief perform a forward operation of operator, save the output to NArray
+ * This method only pushes an execution request to the DAG engine, and
+ * return immediately. Actual execution is conducted by the DAG engine.
* \param opt option on Forward such as whether this is training phase
* \param ctx runtime context
* \param in_data array of input data, it is const
* \param out_data array of output data,
- * the space of TBlob in out_data must be pre-allocated with InferShape
+ * the space of NArray in out_data must be pre-allocated with InferShape
+ * \sa NArray
*/
virtual void Forward(Option opt,
RunContext ctx,
const std::vector<NArray> &in_data,
-const std::vector<NArray> &out_data);
+const std::vector<NArray> &out_data) = 0;
/*!
* \brief perform a backward operation of the operator to get the gradient
+ * This method only pushes an execution request to the DAG engine, and
+ * return immediately. Actual execution is conducted by the DAG engine.
* \param ctx runtime context
- * \param grad_next the gradient value we get from output of the operator
+ * \param grad_next the gradient value of the output of the operator, used by chain rule.
* \param in_data the array of input data
- * \param out_grad array of output gradient, there could be three possible TBlob
- * in the each element in the array
+ * \param out_grad array of output gradient
* \param req request types of the gradient saving operation
* only inplace will change input data
- * \sa GradReqType
+ * \sa GradReqType, NArray
*/
virtual void Backward(RunContext ctx,
const std::vector<NArray> &grad_next,
const std::vector<NArray> &in_data,
const std::vector<NArray> &out_grad,
-const std::vector<GradReqType> &req);
+const std::vector<GradReqType> &req) = 0;

- private:
-  /* \brief the static operator */
-  StaticOperator* op;
+ protected:
+  /**
+   * \brief the global context denots the device info.
+   */
+  Context global_ctx;
-};
+};  // class operator
} // namespace mxnet
#endif // MXNET_OPERATOR_H_
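Together with the symbol.h change below, the intended call pattern is: bind a Symbol to obtain an engine-schedulable Operator, call Forward or Backward (which only enqueue work on the DAG engine), then synchronize through the engine. A usage sketch; the default-constructed Option and the pre-allocated, shape-consistent NArrays are assumptions:

// Sketch: one asynchronous forward pass scheduled on the DAG engine.
#include <mxnet/symbol.h>
#include <mxnet/operator.h>
#include <vector>

void ForwardOnce(const mxnet::Symbol& net, mxnet::Context ctx,
                 mxnet::RunContext rctx,
                 const std::vector<mxnet::NArray>& inputs,
                 const std::vector<mxnet::NArray>& outputs) {
  mxnet::Operator* op = net.Bind(ctx);      // caller owns the returned pointer
  mxnet::Option opt;                        // assumed default-constructible
  op->Forward(opt, rctx, inputs, outputs);  // enqueues and returns immediately
  for (const auto& out : outputs)           // block until the engine has run it
    mxnet::DAGEngine::Get()->WaitForVar(out.var());
  delete op;
}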
32 changes: 1 addition & 31 deletions include/mxnet/static_operator.h
@@ -23,15 +23,6 @@ namespace mxnet {
*/
class StaticOperator {
public:
-/*!
- * \brief get types of input argument of this oeprator
- * \return a vector corresponding to type of each argument
- * this order is same as the order of inputs in Forward, InferShape and Backward
- */
-virtual std::vector<ArgType> DescribeArgs() const {
-  // default most of layers only have one data argument
-  return std::vector<ArgType>(1, kDataArg);
-}
/*!
* \brief describe property of op
* \return a bit map in int
@@ -40,27 +31,6 @@ class StaticOperator {
// default most of layer only conatin internal state
return kContainInteralState;
}
-/*!
- * \brief set param for the StaticOperator from string
- * \param name parameter name
- * \param val string for configuration
- */
-virtual void SetParam(const char *name, const char *val) {}
-/*!
- * \brief inter the shapes of outputs and unknown input arguments
- * \param in_shape the shape of input arguments of the StaticOperator
- * this should be of same length as the vector returned by DescribeArgs
- * in_shape allows unknown elements, which are checked by shape.ndim() == 0.
- * For unknown shapes, InferShape will try to fill in the correct Shape in in_shape
- * For known shapes, InferShape will check shape consistency
- *
- * common practice: set the shape of data input, and usually weight's shape can be infered
- *
- * \param out_shape the shape of outputs of the StaticOperator
- * InferShape will modify the vector to fill output TShape
- */
-virtual void InferShape(std::vector<TShape> *in_shape,
-                        std::vector<TShape> *out_shape) = 0;
/*!
* \brief perform a forward operation of StaticOperator, save the output to TBlob
* \param opt option on Forward such as whether this is training phase
@@ -90,7 +60,7 @@ class StaticOperator {
const std::vector<TBlob> &out_grad,
const std::vector<GradReqType> &req) = 0;
/*!
- * \brief factory unction, create a new StaticOperator
+ * \brief factory function, create a new StaticOperator
* \param type the type of StaticOperator
* \param ctx the context device type of StaticOperator
* \return a pointer of StaticOperator object
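StaticOperator keeps the synchronous, TBlob-based compute interface; scheduling now lives entirely in Operator and its wrapper. A minimal subclass sketch under the signatures visible above; the exact Forward parameter list is inferred by symmetry with Operator::Forward and is an assumption:

// Sketch: a StaticOperator whose compute methods are placeholders.
#include <mxnet/static_operator.h>
#include <vector>

class IdentityOp : public mxnet::StaticOperator {
 public:
  virtual void Forward(mxnet::Option opt, mxnet::RunContext ctx,
                       const std::vector<mxnet::TBlob>& in_data,
                       const std::vector<mxnet::TBlob>& out_data) {
    // A real operator would launch mshadow kernels on the TBlobs here.
  }
  virtual void Backward(mxnet::RunContext ctx,
                        const std::vector<mxnet::TBlob>& grad_next,
                        const std::vector<mxnet::TBlob>& in_data,
                        const std::vector<mxnet::TBlob>& out_grad,
                        const std::vector<mxnet::GradReqType>& req) {
    // Identity gradient: write grad_next into out_grad according to req.
  }
};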
8 changes: 4 additions & 4 deletions include/mxnet/symbol.h
@@ -15,9 +15,9 @@
#include <unordered_map>
#include "./base.h"
#include "./tensor_blob.h"
#include "./operator.h"

namespace mxnet {
-class NArrayOperator;
/*!
* \brief Symbol is the wrapper of AtomicSymbol, the reason for this indirection is that Symbol
* should support expressions and often passed by value. While AtomicSymbol have many subclasses,
@@ -67,11 +67,11 @@ class Symbol {
*/
virtual ~Symbol() {}
/*!
- * \brief bind to device and returns an NArrayOperator.
+ * \brief bind to device and returns an operator.
* \param ctx context of the operator
- * \return returns the pointer to a created NArrayOperator. It is on the user to delete.
+ * \return returns the pointer to a created operator. It is on the user to delete.
*/
-virtual NArrayOperator* Bind(Context ctx) const { return nullptr; }
+virtual Operator* Bind(Context ctx) const { return nullptr; }
/*!
* \brief copy the symbol
* \return a deep copy of the graph
src/operator/static_operator_wrapper.cc
@@ -4,56 +4,22 @@
* \brief the implementation of narray operator
* \author Naiyan Wang
*/
-#include <mxnet/operator.h>
+#include "./static_operator_wrapper.h"

namespace mxnet {

-Operator::Operator(StaticOperator* op, Context ctx) {
+StaticOperatorWrapper::StaticOperatorWrapper(StaticOperator* op, Context ctx) {
this->op = op;
this->global_ctx = ctx;
}
-/*!
- * \brief get types of input argument of this oeprator
- * \return a vector corresponding to type of each argument
- * this order is same as the order of inputs in Forward, InferShape and Backward
- */
-std::vector<ArgType> Operator::DescribeArgs() const {
-  // default most of layers only have one data argument
-  return op->DescribeArgs();
-}
/*!
* \brief describe property of op
* \return a bit map in int
*/
-int Operator::DescribeProperty() const {
+int StaticOperatorWrapper::DescribeProperty() const {
// default most of layer only conatin internal state
return op->DescribeProperty();
}
-/*!
- * \brief set param for the operator from string
- * \param name parameter name
- * \param val string for configuration
- */
-void Operator::SetParam(const char *name, const char *val) {
-  op->SetParam(name, val);
-}
-/*!
- * \brief inter the shapes of outputs and unknown input arguments
- * \param in_shape the shape of input arguments of the operator
- * this should be of same length as the vector returned by DescribeArgs
- * in_shape allows unknown elements, which are checked by shape.ndim() == 0.
- * For unknown shapes, InferShape will try to fill in the correct Shape in in_shape
- * For known shapes, InferShape will check shape consistency
- *
- * common practice: set the shape of data input, and usually weight's shape can be infered
- *
- * \param out_shape the shape of outputs of the operator
- * InferShape will modify the vector to fill output TShape
- */
-void Operator::InferShape(std::vector<TShape> *in_shape,
-                          std::vector<TShape> *out_shape) {
-  op->InferShape(in_shape, out_shape);
-}
/*!
* \brief perform a forward operation of operator, save the output to TBlob
* \param opt option on Forward such as whether this is training phase
@@ -62,7 +28,7 @@ namespace mxnet {
* \param out_data array of output data,
* the space of TBlob in out_data must be pre-allocated with InferShape
*/
-void Operator::Forward(Option opt,
+void StaticOperatorWrapper::Forward(Option opt,
RunContext ctx,
const std::vector<NArray> &in_data,
const std::vector<NArray> &out_data) {
@@ -71,11 +37,11 @@ namespace mxnet {
std::vector<TBlob> in;
std::vector<TBlob> out;
for (size_t i = 0; i < in_data.size(); ++i) {
-used_var.push_back(in_data[i].Var());
+used_var.push_back(in_data[i].var());
in.push_back(in_data[i].data());
}
for (size_t i = 0; i < out_data.size(); ++i) {
-mutate_var.push_back(out_data[i].Var());
+mutate_var.push_back(out_data[i].var());
out.push_back(out_data[i].data());
}
DAGEngine::Get()->Push([this, opt, ctx, in, out](RunContext ctx) {
@@ -93,7 +59,7 @@ namespace mxnet {
* only inplace will change input data
* \sa GradReqType
*/
-void Operator::Backward(RunContext ctx,
+void StaticOperatorWrapper::Backward(RunContext ctx,
const std::vector<NArray> &grad_next,
const std::vector<NArray> &in_data,
const std::vector<NArray> &out_grad,
@@ -104,23 +70,23 @@ namespace mxnet {
std::vector<TBlob> grad_out;
std::vector<TBlob> data;
for (size_t i = 0; i < grad_next.size(); ++i) {
-used_var.push_back(grad_next[i].Var());
+used_var.push_back(grad_next[i].var());
grad_in.push_back(grad_next[i].data());
}
for (size_t i = 0; i < in_data.size(); ++i) {
-used_var.push_back(in_data[i].Var());
+used_var.push_back(in_data[i].var());
data.push_back(in_data[i].data());
}
for (size_t i = 0; i < out_grad.size(); ++i) {
-mutate_var.push_back(out_grad[i].Var());
+mutate_var.push_back(out_grad[i].var());
grad_out.push_back(out_grad[i].data());
}
DAGEngine::Get()->Push([this, ctx, grad_in, grad_out, data, req](RunContext ctx) {
op->Backward(ctx, grad_in, data, grad_out, req);
}, global_ctx, used_var, mutate_var);
}

-void Operator::SetContext(Context ctx) {
+void StaticOperatorWrapper::SetContext(Context ctx) {
this->global_ctx = ctx;
}

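The wrapper is the glue between the two layers: it owns a synchronous StaticOperator and republishes each call as a DAG-engine task whose dependencies are the var() handles of the NArrays involved. Putting the pieces together; the Create signature is read off the static_operator.h doc comment and the "fullc" type string is a guess, so treat both as assumptions:

// Sketch: turn a static operator into an engine-schedulable operator.
#include <mxnet/static_operator.h>
#include "./static_operator_wrapper.h"

mxnet::Operator* MakeEngineOp(mxnet::Context ctx) {
  // "fullc" as a type string is an assumption for illustration.
  mxnet::StaticOperator* sop = mxnet::StaticOperator::Create("fullc", ctx);
  return new mxnet::StaticOperatorWrapper(sop, ctx);
}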