Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Add #18

Merged
merged 8 commits into from
Aug 17, 2015
Merged

Add #18

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -58,14 +58,14 @@ endif
BIN = test/api_registry_test test/test_storage
OBJ = narray_op_cpu.o
# add threaded engine after it is done
OBJCXX11 = engine.o narray.o c_api.o registry.o symbol.o storage.o fully_connected_cpu.o static_graph.o
OBJCXX11 = engine.o narray.o c_api.o registry.o symbol.o storage.o fully_connected_cpu.o static_graph.o activation_cpu.o elementwise_sum_cpu.o
CUOBJ =
SLIB = lib/libmxnet.so
ALIB = lib/libmxnet.a
LIB_DEP = $(DMLC_CORE)/libdmlc.a

ifeq ($(USE_CUDA), 1)
CUOBJ += narray_op_gpu.o fully_connected_gpu.o
CUOBJ += narray_op_gpu.o fully_connected_gpu.o activation_gpu.o elementwise_sum_gpu.o
endif

.PHONY: clean all test lint doc
Expand All @@ -87,7 +87,10 @@ c_api.o: src/c_api.cc
operator.o: src/operator/static_operator_wrapper.cc
fully_connected_cpu.o: src/operator/fully_connected.cc
fully_connected_gpu.o: src/operator/fully_connected.cu

activation_cpu.o: src/operator/activation.cc
activation_gpu.o: src/operator/activation.cu
elementwise_sum_cpu.o: src/operator/elementwise_sum.cc
elementwise_sum_gpu.o: src/operator/elementwise_sum.cu

lib/libmxnet.a: $(OBJ) $(OBJCXX11) $(CUOBJ)
lib/libmxnet.so: $(OBJ) $(OBJCXX11) $(CUOBJ)
Expand Down
59 changes: 35 additions & 24 deletions include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,9 @@ typedef void *DataIterHandle;
* \return error info
*/
MXNET_DLL const char *MXGetLastError();

//--------------------------------
//-------------------------------------
// Part 1: NArray creation and deletion
//--------------------------------
//-------------------------------------
/*!
* \brief create a NArray handle that is not initialized
* can be used to pass in as mutate variables
Expand Down Expand Up @@ -189,7 +188,6 @@ MXNET_DLL int MXFuncDescribe(FunctionHandle fun,
mx_uint *num_scalars,
mx_uint *num_mutate_vars,
int *type_mask);

/*!
* \brief invoke a function, the array size of passed in arguments
* must match the values in the
Expand Down Expand Up @@ -301,8 +299,8 @@ MXNET_DLL int MXSymbolListArguments(SymbolHandle symbol,
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXSymbolListReturns(SymbolHandle symbol,
mx_uint *out_size,
const char ***out_str_array);
mx_uint *out_size,
const char ***out_str_array);
/*!
* \brief Compose the symbol on other symbols.
*
Expand All @@ -322,6 +320,37 @@ MXNET_DLL int MXSymbolCompose(SymbolHandle sym,
mx_uint num_args,
const char** keys,
SymbolHandle* args);
/*!
* \brief infer shape of unknown input shapes given the known one.
* The shapes are packed into a CSR matrix represented by arg_ind_ptr and arg_shape_data
* The call will be treated as a kwargs call if key != nullptr or num_args==0, otherwise it is positional.
*
* \param sym symbol handle
* \param num_args number of input arguments.
* \param keys the key of keyword args (optional)
* \param arg_ind_ptr the head pointer of the rows in CSR
* \param arg_shape_data the content of the CSR
* \param in_shape_size size of the returning array of in_shapes
* \param in_shape_ndim returning array of shape dimensions of each input shape.
* \param in_shape_data returning array of pointers to head of each input shape.
* \param out_shape_size size of the returning array of out_shapes
* \param out_shape_ndim returning array of shape dimensions of each output shape.
* \param out_shape_data returning array of pointers to head of each output shape.
* \param complete whether infer shape completes or more information is needed.
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXSymbolInferShape(SymbolHandle sym,
mx_uint num_args,
const char** keys,
const mx_uint *arg_ind_ptr,
const mx_uint *arg_shape_data,
mx_uint *in_shape_size,
const mx_uint **in_shape_ndim,
const mx_uint ***in_shape_data,
mx_uint *out_shape_size,
const mx_uint **out_shape_ndim,
const mx_uint ***out_shape_data,
int *complete);
//--------------------------------------------
// Part 4: operator interface on NArray
//--------------------------------------------
Expand Down Expand Up @@ -352,24 +381,6 @@ MXNET_DLL int MXOpFree(OperatorHandle op);
*/
MXNET_DLL int MXOpDescribeArgs(mx_uint *out_size,
int **out_array);
/*!
* \brief infer shape of unknown input shapes given the known one
* this function do not return the shape of output
* the shapes are packed into a CSR matrix represened by ind_ptr and shape_array
*
* When the function returns, it return a new CSR matrix by updating ind_ptr,
* and return the content in the return value
*
* \param ind_ptr the head pointer of the rows in CSR
* \param shape_array the content of the CSR
* \param out_nout number of output arguments of this operation
* \param out_array another content of CSR with infered shape
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXOpInferShape(mx_uint *ind_ptr,
mx_uint *shape_array,
mx_uint *out_nout,
mx_uint *out_array);
/*!
* \brief call forward on the operator
* \param op the operator handle
Expand Down
2 changes: 2 additions & 0 deletions include/mxnet/context.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
#ifndef MXNET_CONTEXT_H_
#define MXNET_CONTEXT_H_

#include "./base.h"

namespace mxnet {

/*! \brief Context information about the execution environment */
Expand Down
42 changes: 27 additions & 15 deletions include/mxnet/operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,16 +40,18 @@ enum OpReqType {
struct OpContext {
/*! \brief whether it is training phase */
int is_train;
/*! \brief Stream we are running on */
void *stream;
/*! \brief RunContext related resources */
RunContext run_ctx;
/*! \brief Resources requested by the operator */
std::vector<Resource> requested;
/*!
* \brief set the RunContext related parts
* \param ctx the context
* \brief get mshadow stream from Context
* \return the mshadow stream
* \tparam xpu the device type of the stream
*/
inline void SetRunContext(const RunContext &ctx) {
stream = ctx.stream;
template<typename xpu>
inline mshadow::Stream<xpu>* get_stream() const {
return static_cast<mshadow::Stream<xpu>*>(run_ctx.stream);
}
};

Expand Down Expand Up @@ -84,13 +86,22 @@ class Operator {
const std::vector<TBlob> &out_data) = 0;
/*!
* \brief Perform a Backward Operation, write gradient to the in_grad.
*
* Convention:
* out_grad.size() == OperatorProperty.NumVisibleReturns()
* out_data.size() == OperatorProperty.NumReturns()
* out_data can contain additional invisible returns that remembers the
* state carried from the Forward pass. For example mask in the dropout.
*
* The gradients are passed from visible returns in this function.
*
* \param ctx runtime context available to this call
* \param out_grad the gradient value we get from output of the Operator
* \param out_grad the gradient value we get from of the Operator.
* \param in_data the array of input data.
* \param out_data the array of output data.
* \param req request types of the saving operation, can be all types.
* \param in_grad the array of gradient we need to write to.
* \sa OpReqType, OpContext
* \sa OpReqType, OpContext, OperatorProperty
*/
virtual void Backward(const OpContext &ctx,
const std::vector<TBlob> &out_grad,
Expand All @@ -115,6 +126,12 @@ class OperatorProperty {
* \brief virtual destructor
*/
virtual ~OperatorProperty() {}
/*!
* \brief Initialize the Operator by setting the parameters
* This function need to be called before all other functions.
* \param kwargs the keyword arguments parameters
*/
virtual void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) = 0;
/*!
* \brief Get input arguments of the Operator.
* \return vector of arguments.
Expand Down Expand Up @@ -148,12 +165,6 @@ class OperatorProperty {
virtual int NumVisibleReturns() const {
return NumReturns();
}
/*!
* \brief Set the parameters of the Operator.
* \param name parameter name
* \param val string for the configuration
*/
virtual void SetParam(const char *name, const char *val) {}
/*!
* \brief infer the shapes of outputs and unknown input arguments
* \param in_shape the shape of input arguments of the operator
Expand All @@ -166,7 +177,8 @@ class OperatorProperty {
*
* \param out_shape the shape of outputs of the operator
* InferShape will modify the vector to fill output TShape
* \return if the shape inference is successful, return true, else return false.
* \return true if the shape inference is successful, false if there is not enough information.
* \throws dmlc::Error if the known arg_shapes are inconsistent.
*/
virtual bool InferShape(std::vector<TShape> *in_shape,
std::vector<TShape> *out_shape) const = 0;
Expand Down
110 changes: 99 additions & 11 deletions include/mxnet/symbolic.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include <memory>
#include <string>
#include <utility>
#include <functional>
#include <unordered_map>
#include <unordered_set>
#include "./base.h"
Expand All @@ -37,22 +38,82 @@ class StaticGraph {
uint32_t source_id;
/*! \brief index of output from the source. */
uint32_t index;
/*! \brief default constructor */
DataEntry() {}
/*!
* \brief constructor with source and index
* \param source_id source id
* \param index node index
*/
DataEntry(uint32_t source_id, uint32_t index)
: source_id(source_id), index(index) {}
/*!
* \brief compare equality
* \param other the other entry to compare
* \return whether two entries equals to each other
*/
inline bool operator==(const DataEntry &other) const {
return source_id == other.source_id && index == other.index;
}
/*!
* \brief comparator, allows to use map
* \param other the other entry to compare
* \return whether two entries is smaller than the other
*/
inline bool operator<(const DataEntry &other) const {
if (source_id == other.source_id) return index < other.index;
return source_id < other.source_id;
}
};
/*! \brief Operation Node in static graph */
/*!
* \brief Operation Node in static graphs.
* There are two types of node, Forward and Backward Node.
*
* - Forward node corresponds to the op.Forward
* - Backward node corresponds to the Backward pass,
* where the corresponding forward node is indicated by backward_source_id.
* The op field in Backward node is nullptr
*
* The reason we explicit support Backward node is to allow special treatment
* such as shape inference and state sharing with Forward pass.
*/
struct Node {
/*! \brief wrapped operator property */
std::unique_ptr<OperatorProperty> op;
/*! \brief name of the node */
std::string name;
/*! \brief inputs (node_id, index) of the node */
std::vector<DataEntry> inputs;
/*!
* \brief If this field is nonnegative, this indicates that this
* Node corresponds to a Backward Operation of an Operator.
* backward_source_id will point to the corresponding Forward Node.
*
* For a normal node, this field is -1.
* When the node is a Backward node, the op field will be nullptr.
*/
int32_t backward_source_id;
/*! \brief default constructor */
Node() : backward_source_id(-1) {}
/*! \return whether the node is forward op node */
inline bool is_forward() const {
return op != nullptr;
}
/*! \return whether the node is backward op node */
inline bool is_backward() const {
return backward_source_id != -1;
}
/*! \return whether the node is variable node */
inline bool is_variable() const {
return op == nullptr && !is_backward();
}
};
/*! \brief all nodes in the graph */
std::vector<Node> nodes;
/*! \brief index is nodes that correspods to arguments */
/*! \brief indices of nodes that correspond to arguments */
std::vector<uint32_t> arg_nodes;
/*! \brief outputs(heads) of the graph */
std::vector<DataEntry> outputs;
/*! \brief head outputs of the graph */
std::vector<DataEntry> heads;
// funtions to help inference in static graph
/*!
* \brief Perform a topological sort on the graph
Expand Down Expand Up @@ -85,8 +146,23 @@ class StaticGraph {
* InferShape will modify the vector to fill output TShape
* \return if the shape inference is successful, return true, else return false.
*/
bool InferShape(std::vector<TShape> *in_shape,
std::vector<TShape> *out_shape) const;
bool InferShape(std::vector<TShape>* in_shape,
std::vector<TShape>* out_shape) const;
/*!
* \brief Add a full backward pass in the static graph.
* This function will add gradient nodes for each heads,
* and add the backward pass to backprop the gradients all
* the way to the arguments.
*
* This will change the nodes field in the StaticGraph, but will not change other fields.
* The head and input of Backward pass will be returned by head_grad_nodes and arg_grads.
*
* \param head_grad_nodes used to store the created head gradient inputs for backward pass.
* \param arg_grads used to store gradients to args; can contain multiple entries if an argument is used by more than one operator
*/
void MakeBackwardPass(std::vector<uint32_t> *head_grad_nodes,
std::vector<std::vector<DataEntry> > *arg_grads);
};

/*!
Expand Down Expand Up @@ -174,19 +250,31 @@ class Symbol {
const std::string& name) const;
/*!
* \brief infer the shapes of outputs and unknown input arguments
* \param in_shape the shape of input arguments of the operator
* \param arg_shapes the shape of input arguments of the operator
* this should be of same length as the vector returned by ListArguments
* arg_shapes allows unknown elements, which are checked by shape.ndim() == 0.
* For unknown shapes, InferShape will try to fill in the correct Shape in arg_shapes
* For known shapes, InferShape will check shape consistency
*
* common practice: set the shape of the data input, and usually the weight's shape can be inferred
*
* \param out_shape the shape of outputs of the operator
* InferShape will modify the vector to fill output TShape
* \return if the shape inference is successful, return true, else return false.
* \param out_shapes used to store the inferred shapes of outputs.
* \return true if the shape inference is successful, false if there is not enough information.
* \throws dmlc::Error if the known arg_shapes are inconsistent.
*/
bool InferShape(std::vector<TShape> *arg_shapes,
std::vector<TShape> *out_shapes) const;
/*!
* \brief infer the shapes by providing shapes of known arguments.
* \param known_arg_shapes map of argument name to shape of arguments with known shapes.
* \param arg_shapes used to store inferred shapes of arguments.
* \param out_shapes used to store inferred shapes of outputs.
* \return true if the shape inference is successful, false if there is not enough information.
* \throws dmlc::Error if the known arg_shapes are inconsistent.
*/
bool InferShape(std::vector<TShape> *in_shape, std::vector<TShape> *out_shape) const;
bool InferShape(const std::unordered_map<std::string, TShape> &known_arg_shapes,
std::vector<TShape> *arg_shapes,
std::vector<TShape> *out_shapes) const;
/*!
* \brief get number of outputs of this symbol
* \return number of outputs
Expand Down
2 changes: 1 addition & 1 deletion python/mxnet/narray.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def shape(self):
pdata = ctypes.POINTER(mx_uint)()
check_call(_LIB.MXNArrayGetShape(
self.handle, ctypes.byref(ndim), ctypes.byref(pdata)))
return tuple(pdata[i] for i in range(ndim.value))
return tuple(pdata[:ndim.value])

@property
def context(self):
Expand Down
Loading