Merge pull request #18 from antinucleon/master
Add
antinucleon committed Aug 17, 2015
2 parents 9f4d31c + 137d109 commit f90b986
Showing 26 changed files with 1,165 additions and 292 deletions.
9 changes: 6 additions & 3 deletions Makefile
@@ -58,14 +58,14 @@ endif
BIN = test/api_registry_test test/test_storage
OBJ = narray_op_cpu.o
# add threaded engine after it is done
OBJCXX11 = engine.o narray.o c_api.o registry.o symbol.o storage.o fully_connected_cpu.o static_graph.o
OBJCXX11 = engine.o narray.o c_api.o registry.o symbol.o storage.o fully_connected_cpu.o static_graph.o activation_cpu.o elementwise_sum_cpu.o
CUOBJ =
SLIB = lib/libmxnet.so
ALIB = lib/libmxnet.a
LIB_DEP = $(DMLC_CORE)/libdmlc.a

ifeq ($(USE_CUDA), 1)
CUOBJ += narray_op_gpu.o fully_connected_gpu.o
CUOBJ += narray_op_gpu.o fully_connected_gpu.o activation_gpu.o elementwise_sum_gpu.o
endif

.PHONY: clean all test lint doc
@@ -87,7 +87,10 @@ c_api.o: src/c_api.cc
operator.o: src/operator/static_operator_wrapper.cc
fully_connected_cpu.o: src/operator/fully_connected.cc
fully_connected_gpu.o: src/operator/fully_connected.cu

activation_cpu.o: src/operator/activation.cc
activation_gpu.o: src/operator/activation.cu
elementwise_sum_cpu.o: src/operator/elementwise_sum.cc
elementwise_sum_gpu.o: src/operator/elementwise_sum.cu

lib/libmxnet.a: $(OBJ) $(OBJCXX11) $(CUOBJ)
lib/libmxnet.so: $(OBJ) $(OBJCXX11) $(CUOBJ)
59 changes: 35 additions & 24 deletions include/mxnet/c_api.h
@@ -49,10 +49,9 @@ typedef void *DataIterHandle;
* \return error info
*/
MXNET_DLL const char *MXGetLastError();

//--------------------------------
//-------------------------------------
// Part 1: NArray creation and deletion
//--------------------------------
//-------------------------------------
/*!
* \brief create a NArray handle that is not initialized
* can be used to pass in as mutate variables
@@ -189,7 +188,6 @@ MXNET_DLL int MXFuncDescribe(FunctionHandle fun,
mx_uint *num_scalars,
mx_uint *num_mutate_vars,
int *type_mask);

/*!
* \brief invoke a function, the array size of passed in arguments
* must match the values in the
@@ -301,8 +299,8 @@ MXNET_DLL int MXSymbolListArguments(SymbolHandle symbol,
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXSymbolListReturns(SymbolHandle symbol,
mx_uint *out_size,
const char ***out_str_array);
mx_uint *out_size,
const char ***out_str_array);
/*!
* \brief Compose the symbol on other symbols.
*
@@ -322,6 +320,37 @@ MXNET_DLL int MXSymbolCompose(SymbolHandle sym,
mx_uint num_args,
const char** keys,
SymbolHandle* args);
/*!
* \brief infer the shapes of unknown inputs given the known ones.
* The shapes are packed into a CSR matrix represented by arg_ind_ptr and arg_shape_data.
* The call is treated as a kwargs call if keys != nullptr or num_args == 0, otherwise it is positional.
*
* \param sym symbol handle
* \param num_args number of input arguments.
* \param keys the keys of keyword args (optional)
* \param arg_ind_ptr the head pointer of the rows in CSR
* \param arg_shape_data the content of the CSR
* \param in_shape_size size of the returning array of in_shapes
* \param in_shape_ndim returning array of shape dimensions of each input shape.
* \param in_shape_data returning array of pointers to the head of each input shape.
* \param out_shape_size size of the returning array of out_shapes
* \param out_shape_ndim returning array of shape dimensions of each output shape.
* \param out_shape_data returning array of pointers to the head of each output shape.
* \param complete whether shape inference completed or more information is needed.
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXSymbolInferShape(SymbolHandle sym,
mx_uint num_args,
const char** keys,
const mx_uint *arg_ind_ptr,
const mx_uint *arg_shape_data,
mx_uint *in_shape_size,
const mx_uint **in_shape_ndim,
const mx_uint ***in_shape_data,
mx_uint *out_shape_size,
const mx_uint **out_shape_ndim,
const mx_uint ***out_shape_data,
int *complete);
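As a usage illustration (not part of this diff), the following hedged C-style sketch packs two argument shapes into the CSR arrays and calls MXSymbolInferShape. The argument names "data"/"weight", the shape (100, 50), and the assumption that keys carry argument names in the kwargs-style call are illustrative only.

// Hedged sketch: infer shapes when only the first of two arguments is known.
#include <cstdio>
#include <mxnet/c_api.h>

void InferShapeExample(SymbolHandle sym) {
  const char *keys[] = {"data", "weight"};   // kwargs-style call (assumed key semantics)
  mx_uint arg_ind_ptr[] = {0, 2, 2};         // CSR row offsets: arg0 has 2 dims, arg1 unknown (empty row)
  mx_uint arg_shape_data[] = {100, 50};      // concatenated dims of all known rows
  mx_uint in_size = 0, out_size = 0;
  const mx_uint *in_ndim = NULL, *out_ndim = NULL;
  const mx_uint **in_data = NULL, **out_data = NULL;
  int complete = 0;
  if (MXSymbolInferShape(sym, 2, keys, arg_ind_ptr, arg_shape_data,
                         &in_size, &in_ndim, &in_data,
                         &out_size, &out_ndim, &out_data, &complete) != 0) {
    std::printf("infer shape failed: %s\n", MXGetLastError());
    return;
  }
  if (complete) {
    for (mx_uint i = 0; i < in_size; ++i) {
      std::printf("argument %u has %u dimensions\n", i, in_ndim[i]);
    }
  }
}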
//--------------------------------------------
// Part 4: operator interface on NArray
//--------------------------------------------
@@ -352,24 +381,6 @@ MXNET_DLL int MXOpFree(OperatorHandle op);
*/
MXNET_DLL int MXOpDescribeArgs(mx_uint *out_size,
int **out_array);
/*!
* \brief infer the shapes of unknown inputs given the known ones
* this function does not return the shape of the output
* the shapes are packed into a CSR matrix represented by ind_ptr and shape_array
*
* When the function returns, it returns a new CSR matrix by updating ind_ptr,
* and returns the content in the return value
*
* \param ind_ptr the head pointer of the rows in CSR
* \param shape_array the content of the CSR
* \param out_nout number of output arguments of this operation
* \param out_array another content of CSR with inferred shape
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXOpInferShape(mx_uint *ind_ptr,
mx_uint *shape_array,
mx_uint *out_nout,
mx_uint *out_array);
/*!
* \brief call forward on the operator
* \param op the operator handle
2 changes: 2 additions & 0 deletions include/mxnet/context.h
@@ -6,6 +6,8 @@
#ifndef MXNET_CONTEXT_H_
#define MXNET_CONTEXT_H_

#include "./base.h"

namespace mxnet {

/*! \brief Context information about the execution environment */
42 changes: 27 additions & 15 deletions include/mxnet/operator.h
@@ -40,16 +40,18 @@ enum OpReqType {
struct OpContext {
/*! \brief whether it is training phase */
int is_train;
/*! \brief Stream we are running on */
void *stream;
/*! \brief RunContext related resources */
RunContext run_ctx;
/*! \brief Resources requested by the operator */
std::vector<Resource> requested;
/*!
* \brief set the RunContext related parts
* \param ctx the context
* \brief get mshadow stream from Context
* \return the mshadow stream
* \tparam xpu the device type of the stream
*/
inline void SetRunContext(const RunContext &ctx) {
stream = ctx.stream;
template<typename xpu>
inline mshadow::Stream<xpu>* get_stream() const {
return static_cast<mshadow::Stream<xpu>*>(run_ctx.stream);
}
};
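For context, a hedged sketch of how an operator body might use the new accessor in place of the old raw stream pointer. The free function, the tensor shapes, and the TBlob::FlatTo2D helper are assumptions for illustration, not code from this commit.

// Hedged sketch of a Forward body using ctx.get_stream<xpu>(); illustrative only.
#include <vector>
#include <mxnet/operator.h>

template<typename xpu>
void ForwardSketch(const mxnet::OpContext &ctx,
                   const std::vector<mxnet::TBlob> &in_data,
                   const std::vector<mxnet::TBlob> &out_data) {
  using namespace mshadow;
  Stream<xpu> *s = ctx.get_stream<xpu>();                    // typed stream recovered from run_ctx.stream
  Tensor<xpu, 2> in = in_data[0].FlatTo2D<xpu, real_t>(s);   // assumed TBlob helper
  Tensor<xpu, 2> out = out_data[0].FlatTo2D<xpu, real_t>(s);
  Copy(out, in, s);                                          // placeholder computation running on that stream
}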

@@ -84,13 +86,22 @@ class Operator {
const std::vector<TBlob> &out_data) = 0;
/*!
* \brief Perform a Backward Operation, write gradient to the in_grad.
*
* Convention:
* out_grad.size() == OperatorProperty.NumVisibleReturns()
* out_data.size() == OperatorProperty.NumReturns()
* out_data can contain additional invisible returns that remember the
* state carried over from the Forward pass, for example the mask in dropout.
*
* Only the gradients of the visible returns are passed to this function.
*
* \param ctx runtime context available to this call
* \param out_grad the gradient value we get from output of the Operator
* \param out_grad the gradient value we get from the output of the Operator.
* \param in_data the array of input data.
* \param out_data the array of output data.
* \param req request types of the saving operation, can be all types.
* \param in_grad the array of gradient we need to write to.
* \sa OpReqType, OpContext
* \sa OpReqType, OpContext, OperatorProperty
*/
virtual void Backward(const OpContext &ctx,
const std::vector<TBlob> &out_grad,
@@ -115,6 +126,12 @@ class OperatorProperty {
* \brief virtual destructor
*/
virtual ~OperatorProperty() {}
/*!
* \brief Initialize the Operator by setting the parameters
* This function needs to be called before all other functions.
* \param kwargs the keyword arguments parameters
*/
virtual void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) = 0;
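As an illustration of the new Init contract (all parameter values arrive as strings), here is a minimal hedged sketch of a parameter holder that an OperatorProperty implementation might delegate to; the struct name, fields, and keys are invented for this example.

#include <cstdlib>
#include <string>
#include <utility>
#include <vector>

// Hypothetical parameter holder; field names and keys are illustrative only.
struct FullyConnectedParamSketch {
  int num_hidden = 0;
  bool no_bias = false;
  // Mirrors what an OperatorProperty::Init override would do with its kwargs:
  // every value arrives as a string and must be converted explicitly.
  void Init(const std::vector<std::pair<std::string, std::string> > &kwargs) {
    for (size_t i = 0; i < kwargs.size(); ++i) {
      if (kwargs[i].first == "num_hidden") {
        num_hidden = std::atoi(kwargs[i].second.c_str());
      } else if (kwargs[i].first == "no_bias") {
        no_bias = (kwargs[i].second == "true" || kwargs[i].second == "1");
      }
    }
  }
};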
/*!
* \brief Get input arguments of the Operator.
* \return vector of arguments.
Expand Down Expand Up @@ -148,12 +165,6 @@ class OperatorProperty {
virtual int NumVisibleReturns() const {
return NumReturns();
}
/*!
* \brief Set the parameters of the Operator.
* \param name parameter name
* \param val string for the configuration
*/
virtual void SetParam(const char *name, const char *val) {}
/*!
* \brief infer the shapes of outputs and unknown input arguments
* \param in_shape the shape of input arguments of the operator
@@ -166,7 +177,8 @@ *
*
* \param out_shape the shape of outputs of the operator
* InferShape will modify the vector to fill output TShape
* \return if the shape inference is successful, return true, else return false.
* \return true if the shape inference is successful, false if there is not enough information.
* \throws dmlc::Error if the known arg_shapes are inconsistent.
*/
virtual bool InferShape(std::vector<TShape> *in_shape,
std::vector<TShape> *out_shape) const = 0;
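To make the contract concrete, here is a hedged sketch of a shape-inference routine for a hypothetical elementwise operator with one input and one output; it is not taken from this commit, and the include paths are assumptions.

#include <vector>
#include <dmlc/logging.h>
#include <mxnet/base.h>   // assumed location of mxnet::TShape

// Hedged sketch of the InferShape contract for a single-input, single-output op.
bool InferElementwiseShape(std::vector<mxnet::TShape> *in_shape,
                           std::vector<mxnet::TShape> *out_shape) {
  CHECK_EQ(in_shape->size(), 1) << "expected exactly one input";  // inconsistent input: fail loudly
  const mxnet::TShape &dshape = (*in_shape)[0];
  if (dshape.ndim() == 0) return false;   // input still unknown: need more information
  out_shape->clear();
  out_shape->push_back(dshape);           // output mirrors the input shape
  return true;
}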
110 changes: 99 additions & 11 deletions include/mxnet/symbolic.h
@@ -12,6 +12,7 @@
#include <memory>
#include <string>
#include <utility>
#include <functional>
#include <unordered_map>
#include <unordered_set>
#include "./base.h"
@@ -37,22 +38,82 @@ class StaticGraph {
uint32_t source_id;
/*! \brief index of output from the source. */
uint32_t index;
/*! \brief default constructor */
DataEntry() {}
/*!
* \brief constructor with source and index
* \param source_id source id
* \param index node index
*/
DataEntry(uint32_t source_id, uint32_t index)
: source_id(source_id), index(index) {}
/*!
* \brief compare equality
* \param other the other entry to compare
* \return whether the two entries equal each other
*/
inline bool operator==(const DataEntry &other) const {
return source_id == other.source_id && index == other.index;
}
/*!
* \brief comparator, allows the entry to be used as a map key
* \param other the other entry to compare
* \return whether this entry is smaller than the other
*/
inline bool operator<(const DataEntry &other) const {
if (source_id == other.source_id) return index < other.index;
return source_id < other.source_id;
}
};
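A small hedged illustration of why operator< is provided: it lets DataEntry act as a key in ordered containers (assuming DataEntry is publicly accessible as StaticGraph::DataEntry); the mapping and node ids are invented.

#include <map>
#include <mxnet/symbolic.h>

// Illustrative only: map a forward output entry to its gradient entry.
void DataEntryMapSketch() {
  using mxnet::StaticGraph;
  std::map<StaticGraph::DataEntry, StaticGraph::DataEntry> grad_of;
  grad_of[StaticGraph::DataEntry(3, 0)] = StaticGraph::DataEntry(7, 1);  // node 3 output 0 -> node 7 output 1
}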
/*! \brief Operation Node in static graph */
/*!
* \brief Operation Node in static graphs.
* There are two types of nodes: Forward and Backward.
*
* - A Forward node corresponds to op.Forward.
* - A Backward node corresponds to the Backward pass,
*   where the corresponding forward node is indicated by backward_source_id.
*   The op field of a Backward node is nullptr.
*
* The reason we explicitly support Backward nodes is to allow special treatment
* such as shape inference and state sharing with the Forward pass.
*/
struct Node {
/*! \brief wrapped operator property */
std::unique_ptr<OperatorProperty> op;
/*! \brief name of the node */
std::string name;
/*! \brief inputs (node_id, index) of the node */
std::vector<DataEntry> inputs;
/*!
* \brief If this field is nonnegative, this
* Node corresponds to a Backward Operation of an Operator.
* backward_source_id will point to the corresponding Forward Node.
*
* For a normal node, this field is -1.
* When the node is a Backward node, the op field will be nullptr.
*/
int32_t backward_source_id;
/*! \brief default constructor */
Node() : backward_source_id(-1) {}
/*! \return whether the node is forward op node */
inline bool is_forward() const {
return op != nullptr;
}
/*! \return whether the node is backward op node */
inline bool is_backward() const {
return backward_source_id != -1;
}
/*! \return whether the node is variable node */
inline bool is_variable() const {
return op == nullptr && !is_backward();
}
};
/*! \brief all nodes in the graph */
std::vector<Node> nodes;
/*! \brief index is nodes that correspods to arguments */
/*! \brief indices of nodes that correspond to arguments */
std::vector<uint32_t> arg_nodes;
/*! \brief outputs(heads) of the graph */
std::vector<DataEntry> outputs;
/*! \brief head outputs of the graph */
std::vector<DataEntry> heads;
// functions to help inference in static graph
/*!
* \brief Perform a topological sort on the graph
@@ -85,8 +146,23 @@
* InferShape will modify the vector to fill output TShape
* \return if the shape inference is successful, return true, else return false.
*/
bool InferShape(std::vector<TShape> *in_shape,
std::vector<TShape> *out_shape) const;
bool InferShape(std::vector<TShape>* in_shape,
std::vector<TShape>* out_shape) const;
/*!
* \brief Add a full backward pass in the static graph.
* This function will add gradient nodes for each head,
* and add the backward pass to backprop the gradients all
* the way to the arguments.
*
* This will change the nodes field in the StaticGraph, but will not change other fields.
* The heads and inputs of the Backward pass will be returned by head_grad_nodes and arg_grads.
*
* \param head_grad_nodes used to store the created head gradient inputs for the backward pass.
* \param arg_grads used to store gradients to args; there can be multiple entries if an argument is used by more than one operator.
*/
void MakeBackwardPass(std::vector<uint32_t> *head_grad_nodes,
std::vector<std::vector<DataEntry> > *arg_grads);
};
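A hedged caller-side sketch of the new MakeBackwardPass entry point; the wrapper function and variable names are invented for illustration.

#include <cstdint>
#include <vector>
#include <mxnet/symbolic.h>

// Hedged sketch: append the backward pass to an existing static graph.
void AddBackwardSketch(mxnet::StaticGraph *graph) {
  std::vector<uint32_t> head_grad_nodes;                               // head gradient inputs created by the pass
  std::vector<std::vector<mxnet::StaticGraph::DataEntry> > arg_grads;  // gradient entries per argument
  graph->MakeBackwardPass(&head_grad_nodes, &arg_grads);
  // arg_grads[i] may hold more than one entry if argument i is used several times.
}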

/*!
@@ -174,19 +250,31 @@ class Symbol {
const std::string& name) const;
/*!
* \brief infer the shapes of outputs and unknown input arguments
* \param in_shape the shape of input arguments of the operator
* \param arg_shapes the shape of input arguments of the operator
* this should be of same length as the vector returned by ListArguments
* in_shape allows unknown elements, which are checked by shape.ndim() == 0.
* For unknown shapes, InferShape will try to fill in the correct Shape in in_shape
* For known shapes, InferShape will check shape consistency
*
* common practice: set the shape of the data input, and usually the weights' shapes can be inferred
*
* \param out_shape the shape of outputs of the operator
* InferShape will modify the vector to fill output TShape
* \return if the shape inference is successful, return true, else return false.
* \param out_shapes used to store the inferred shapes of outputs.
* \return true if the shape inference is successful, false if there is not enough information.
* \throws dmlc::Error if the known arg_shapes are inconsistent.
*/
bool InferShape(std::vector<TShape> *arg_shapes,
std::vector<TShape> *out_shapes) const;
/*!
* \brief infer the shapes by providing shapes of known arguments.
* \param known_arg_shapes map of argument name to shape of arguments with known shapes.
* \param arg_shapes used to store inferred shapes of arguments.
* \param out_shapes used to store inferred shapes of outputs.
* \return true if the shape inference is successful, false if there is not enough information.
* \throws dmlc::Error if the known arg_shapes are inconsistent.
*/
bool InferShape(std::vector<TShape> *in_shape, std::vector<TShape> *out_shape) const;
bool InferShape(const std::unordered_map<std::string, TShape> &known_arg_shapes,
std::vector<TShape> *arg_shapes,
std::vector<TShape> *out_shapes) const;
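A hedged sketch of the keyword-style overload: supply only the data shape and let the remaining argument shapes be inferred. The symbol, the "data" argument name, and the (100, 50) shape are assumptions for illustration.

#include <string>
#include <unordered_map>
#include <vector>
#include <mxnet/symbolic.h>

// Hedged sketch: infer argument and output shapes from a known data shape.
void InferFromDataSketch(const mxnet::Symbol &net) {
  std::unordered_map<std::string, mxnet::TShape> known;
  known["data"] = mshadow::Shape2(100, 50);      // batch of 100, feature dimension 50 (illustrative)
  std::vector<mxnet::TShape> arg_shapes, out_shapes;
  if (net.InferShape(known, &arg_shapes, &out_shapes)) {
    // arg_shapes aligns with net.ListArguments(); out_shapes with the symbol's returns.
  }
}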
/*!
* \brief get number of outputs of this symbol
* \return number of outputs
2 changes: 1 addition & 1 deletion python/mxnet/narray.py
@@ -134,7 +134,7 @@ def shape(self):
pdata = ctypes.POINTER(mx_uint)()
check_call(_LIB.MXNArrayGetShape(
self.handle, ctypes.byref(ndim), ctypes.byref(pdata)))
return tuple(pdata[i] for i in range(ndim.value))
return tuple(pdata[:ndim.value])

@property
def context(self):