From 3d515836359966f40a649a1a50f7e48e75d4d009 Mon Sep 17 00:00:00 2001
From: Bing Xu
Date: Sat, 20 Jun 2015 21:21:27 -0600
Subject: [PATCH 01/11] chg readme

---
 README.md | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/README.md b/README.md
index d2d7f138e26e..647e0f02d093 100644
--- a/README.md
+++ b/README.md
@@ -1,16 +1,15 @@
 # MXNet
-This is an experimental project to put cxxnet and minerva together, nothing is working yet.
+This is a project that combines lessons and ideas we learned from [cxxnet](https://github.com/dmlc/cxxnet), [minerva](https://github.com/dmlc/minerva) and [purine2](https://github.com/purine/purine2).
+- The interface is designed in collaboration by the authors of the three projects.
+- Nothing is working yet.
 
 # Guidelines
 * Use Google C++ style
 * Put module headers in [include](include)
-  - move them to ```project-name/include``` once we finalize the name
 * Depend on [dmlc-core](https://github.com/dmlc/dmlc-core)
 * Doxygen-comment every function, class and variable in the module headers
   - Ref headers in [dmlc-core/include](https://github.com/dmlc/dmlc-core/tree/master/include/dmlc)
   - Use the same style as dmlc-core
-* Try to write some use cases of the interface in [test](test)
-  - They do not need to link, but they need to compile
 * Minimize dependencies; if possible only depend on dmlc-core
 * Macro-guard C++11 code:
   - Try to make the interface compile when C++11 is not available (but with some functionality missing)

From c7482d123d012dfc75f4b906d0415a44fdb22ffd Mon Sep 17 00:00:00 2001
From: Bing Xu
Date: Sat, 27 Jun 2015 11:27:53 -0600
Subject: [PATCH 02/11] update on function registry

---
 include/mxnet/api_registry.h | 27 ++++++++++++++-------------
 include/mxnet/narray.h       | 34 ++++++++++++++++++++++++++++++++++
 src/api_registry.cc          | 10 +++++-----
 src/narray/narray.cc         | 20 ++++++++++++++++++++
 src/narray/narray_op.h       |  1 +
 test/api_registry_test.cc    |  2 +-
 6 files changed, 75 insertions(+), 19 deletions(-)

diff --git a/include/mxnet/api_registry.h b/include/mxnet/api_registry.h
index 408601006910..c6cbdb6db89b 100644
--- a/include/mxnet/api_registry.h
+++ b/include/mxnet/api_registry.h
@@ -18,12 +18,12 @@
 namespace mxnet {
 
 /*! \brief registry of NArray functions */
-class NArrayFunRegistry {
+class FunctionRegistry {
  public:
   /*! \brief definition of NArray function */
-  typedef std::function<void (NArray **used_vars,
-                              real_t *scalars,
-                              NArray **mutate_vars)> NArrayFun;
+  typedef std::function<void (NArray **used_vars,
+                              real_t *scalars,
+                              NArray **mutate_vars)> Function;
   /*! \brief registry entry */
   struct Entry {
     /*! \brief function name */
     std::string name;
     /*! \brief number of variables used by this function */
     unsigned num_use_vars;
     /*! \brief number of variables mutated by this function */
     unsigned num_mutate_vars;
     /*! \brief number of scalars used by this function */
     unsigned num_scalars;
     /*! \brief the real function */
-    NArrayFun body;
+    Function body;
     /*!
      * \brief constructor
      * \param name name of the function
@@ -75,7 +75,7 @@
      * \param f function body to set
      * \return ref to the registered entry, used to set properties
      */
-    inline Entry &set_body(NArrayFun f) {
+    inline Entry &set_body(Function f) {
       body = f; return *this;
     }
     /*!
@@ -106,7 +106,7 @@
     }
   };  // Entry
   /*! \return get a singleton */
-  static NArrayFunRegistry *Get();
+  static FunctionRegistry *Get();
   /*!
    * \brief register a function under the given name
    * \param name name of the function
    */
   Entry &Register(const std::string name);
   /*! \return list of functions in the registry */
-  inline const std::vector<Entry*> &List() const {
-    return fun_list_;
+  inline static const std::vector<Entry*> &List() {
+    return Get()->fun_list_;
   }
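
A minimal usage sketch (not part of the patch) of the registry above: a body matching the `Function` typedef, registered through the `REGISTER_NARRAY_FUN` macro defined further down in this header. `ExampleFun` and the name `example` are hypothetical; the sketch assumes this header and `include/mxnet/narray.h` are on the include path.

```cpp
#include <mxnet/narray.h>
#include <mxnet/api_registry.h>

// Hypothetical body: one input, one output, no scalars; the signature
// matches FunctionRegistry::Function.
void ExampleFun(mxnet::NArray **used_vars,
                mxnet::real_t *scalars,
                mxnet::NArray **mutate_vars) {
  (void)scalars;                        // num_scalars == 0
  mxnet::NArray *out = mutate_vars[0];  // num_mutate_vars == 1
  *out += *used_vars[0];                // num_use_vars == 1
}

// Register it under the name "example"; set_body stores the callable.
REGISTER_NARRAY_FUN(example).set_body(ExampleFun);
```
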
   /*!
    * \brief find a function entry with the corresponding name
    * \param name name of the function
    * \return the corresponding function entry, can be NULL
    */
-  inline const Entry *Find(const std::string &name) const {
-    auto p = fmap_.find(name);
-    if (p != fmap_.end()) {
+  inline static const Entry *Find(const std::string &name) {
+    auto &fmap = Get()->fmap_;
+    auto p = fmap.find(name);
+    if (p != fmap.end()) {
       return p->second;
     } else {
       return nullptr;
     }
   }
@@ -137,9 +138,9 @@ class NArrayFunRegistry {
   /*! \brief map of name->function */
   std::map<std::string, Entry*> fmap_;
   /*! \brief constructor */
-  NArrayFunRegistry() {}
+  FunctionRegistry() {}
   /*! \brief destructor */
-  ~NArrayFunRegistry();
+  ~FunctionRegistry();
 };
 
 /*!
@@ -159,7 +160,7 @@ class NArrayFunRegistry {
  */
 #define REGISTER_NARRAY_FUN(name)                                   \
   static auto __ ## name ## _narray_fun__ =                         \
-    ::mxnet::NArrayFunRegistry::Get()->Register("" # name)
+    ::mxnet::FunctionRegistry::Get()->Register("" # name)
 
 }  // namespace mxnet
 #endif  // MXNET_API_REGISTRY_H_
diff --git a/include/mxnet/narray.h b/include/mxnet/narray.h
index 0fd40512c35e..8e189a6ee960 100644
--- a/include/mxnet/narray.h
+++ b/include/mxnet/narray.h
@@ -59,6 +59,40 @@ class NArray {
   inline bool is_none() const {
     return ptr_.get() == nullptr;
   }
+  /*!
+   * \brief set all the elements in the narray to the given scalar
+   * \param scalar the scalar to set
+   * \return reference of self
+   */
+  NArray &operator=(real_t scalar);
+  /*!
+   * \brief elementwise add to the current narray
+   * this mutates the current NArray
+   * \param src the data to add
+   * \return reference of self
+   */
+  NArray &operator+=(const NArray &src);
+  /*!
+   * \brief elementwise subtract from the current narray
+   * this mutates the current NArray
+   * \param src the data to subtract
+   * \return reference of self
+   */
+  NArray &operator-=(const NArray &src);
+  /*!
+   * \brief elementwise multiply into the current narray
+   * this mutates the current NArray
+   * \param src the data to multiply by
+   * \return reference of self
+   */
+  NArray &operator*=(const NArray &src);
+  /*!
+   * \brief elementwise divide the current narray
+   * this mutates the current NArray
+   * \param src the data to divide by
+   * \return reference of self
+   */
+  NArray &operator/=(const NArray &src);
  private:
   /*!
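
A usage sketch (not from the patch) of the operators declared above, assuming operands of identical shape living in the same context; `Demo` is a hypothetical name, and the non-mutating `operator+` is declared elsewhere in this header.

```cpp
mxnet::NArray Demo(const mxnet::NArray &a, const mxnet::NArray &b) {
  mxnet::NArray c = a + b;  // allocates a new result NArray
  c *= b;                   // elementwise multiply, mutates c in place
  c /= a;                   // elementwise divide, mutates c in place
  return c;
}
```
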
\brief the real data chunk that backs NArray */ diff --git a/src/api_registry.cc b/src/api_registry.cc index 43549be55376..151ad9c7ad37 100644 --- a/src/api_registry.cc +++ b/src/api_registry.cc @@ -4,8 +4,8 @@ namespace mxnet { -NArrayFunRegistry::Entry & -NArrayFunRegistry::Register(const std::string name) { +FunctionRegistry::Entry & +FunctionRegistry::Register(const std::string name) { CHECK(fmap_.count(name) == 0); Entry *e = new Entry(name); fmap_[name] = e; @@ -15,14 +15,14 @@ NArrayFunRegistry::Register(const std::string name) { return *e; } -NArrayFunRegistry::~NArrayFunRegistry() { +FunctionRegistry::~FunctionRegistry() { for (auto p = fmap_.begin(); p != fmap_.end(); ++p) { delete p->second; } } -NArrayFunRegistry *NArrayFunRegistry::Get() { - static NArrayFunRegistry instance; +FunctionRegistry *FunctionRegistry::Get() { + static FunctionRegistry instance; return &instance; } diff --git a/src/narray/narray.cc b/src/narray/narray.cc index abc9b499b993..c4b0553f937d 100644 --- a/src/narray/narray.cc +++ b/src/narray/narray.cc @@ -52,6 +52,13 @@ inline NArray BinaryEWiseRet(const NArray &lhs, return ret; } +template +inline NArray &BinaryEWiseApply(NArray *dst, + const NArray &src) { + BinaryEWise(*dst, src, dst); + return *dst; +} + NArray operator+(const NArray &lhs, const NArray &rhs) { return BinaryEWiseRet(lhs, rhs); } @@ -65,7 +72,20 @@ NArray operator/(const NArray &lhs, const NArray &rhs) { return BinaryEWiseRet(lhs, rhs); } +NArray &NArray::operator+=(const NArray &src) { + return BinaryEWiseApply(this, src); +} +NArray &NArray::operator-=(const NArray &src) { + return BinaryEWiseApply(this, src); +} +NArray &NArray::operator*=(const NArray &src) { + return BinaryEWiseApply(this, src); +} +NArray &NArray::operator/=(const NArray &src) { + return BinaryEWiseApply(this, src); +} +// register API function REGISTER_NARRAY_FUN(Plus).set_function(BinaryEWise); REGISTER_NARRAY_FUN(Minus).set_function(BinaryEWise); REGISTER_NARRAY_FUN(Mul).set_function(BinaryEWise); diff --git a/src/narray/narray_op.h b/src/narray/narray_op.h index bbdb3e1e53b3..72abf98fa3d3 100644 --- a/src/narray/narray_op.h +++ b/src/narray/narray_op.h @@ -16,6 +16,7 @@ namespace narray { struct BinaryBase { inline static TShape GetShape(const TShape &lshape, const TShape &rshape) { CHECK(lshape == rshape) << "operands shape mismatch"; + CHECK(lshape.ndim() != 0) << "source operand have zero dimension shape"; return lshape; } }; diff --git a/test/api_registry_test.cc b/test/api_registry_test.cc index b361ca2242e0..dd6604c6cdb9 100644 --- a/test/api_registry_test.cc +++ b/test/api_registry_test.cc @@ -3,7 +3,7 @@ #include int main(int argc, char *argv[]) { - auto fadd = mxnet::NArrayFunRegistry::Get()->Find("Plus"); + auto fadd = mxnet::FunctionRegistry::Find("Plus"); printf("f.name=%s\n", fadd->name.c_str()); return 0; } From b0e41fd24e27d411d0ae4f3a62eba9a137c83bf1 Mon Sep 17 00:00:00 2001 From: Bing Xu Date: Sat, 27 Jun 2015 14:57:11 -0600 Subject: [PATCH 03/11] api --- Makefile | 15 ++-- api/mxnet_api.cc | 160 +++++++++++++++++++++++++++++++++++++++++ api/mxnet_api.h | 16 +++-- include/mxnet/base.h | 8 +++ include/mxnet/narray.h | 11 +++ 5 files changed, 199 insertions(+), 11 deletions(-) diff --git a/Makefile b/Makefile index 0aa41cc4fd57..60abe5c232ca 100644 --- a/Makefile +++ b/Makefile @@ -50,12 +50,13 @@ BIN = test/api_registry_test OBJ = storage.o narray_op_cpu.o operator.o operator_cpu.o OBJCXX11 = engine.o narray.o mxnet_api.o api_registry.o CUOBJ = narray_op_gpu.o operator_gpu.o - +SLIB = 
api/libmxnet.so +ALIB = api/libmxnet.a LIB_DEP = $(DMLC_CORE)/libdmlc.a .PHONY: clean all -all: $(OBJ) $(OBJCXX11) $(CUOBJ) $(BIN) +all: $(ALIB) $(SLIB) $(BIN) $(DMLC_CORE)/libdmlc.a: + cd $(DMLC_CORE); make libdmlc.a config=$(ROOTDIR)/$(config); cd $(ROOTDIR) @@ -71,7 +72,10 @@ operator_gpu.o: src/operator/operator_gpu.cu api_registry.o: src/api_registry.cc mxnet_api.o: api/mxnet_api.cc -test/api_registry_test: test/api_registry_test.cc $(OBJ) $(OBJCXX11) $(CUOBJ) +api/libmxnet.a: $(OBJ) $(OBJCXX11) $(CUOBJ) +api/libmxnet.so: $(OBJ) $(OBJCXX11) $(CUOBJ) + +test/api_registry_test: test/api_registry_test.cc api/libmxnet.a $(BIN) : $(CXX) $(CFLAGS) -std=c++11 -o $@ $(filter %.cpp %.o %.c %.a %.cc, $^) $(LDFLAGS) @@ -85,6 +89,9 @@ $(OBJCXX11) : $(SLIB) : $(CXX) $(CFLAGS) -shared -o $@ $(filter %.cpp %.o %.c %.a %.cc, $^) $(LDFLAGS) +$(ALIB): + ar cr $@ $+ + $(CUOBJ) : $(NVCC) -c -o $@ $(NVCCFLAGS) -Xcompiler "$(CFLAGS)" $(filter %.cu, $^) @@ -92,5 +99,5 @@ $(CUBIN) : $(NVCC) -o $@ $(NVCCFLAGS) -Xcompiler "$(CFLAGS)" -Xlinker "$(LDFLAGS)" $(filter %.cu %.cpp %.o, $^) clean: - $(RM) $(OBJ) $(OBJCXX11) $(BIN) $(CUBIN) $(CUOBJ) $(SLIB) *~ */*~ */*/*~ + $(RM) $(OBJ) $(OBJCXX11) $(BIN) $(CUBIN) $(CUOBJ) $(SLIB) $(ALIB) *~ */*~ */*/*~ cd $(DMLC_CORE); make clean; cd - diff --git a/api/mxnet_api.cc b/api/mxnet_api.cc index c38f5dc99092..3c4e52dffacc 100644 --- a/api/mxnet_api.cc +++ b/api/mxnet_api.cc @@ -1,4 +1,164 @@ #include #include +#include #include "./mxnet_api.h" +// NOTE: all functions return 0 upon success +// consider add try/catch block for user error +// handling in the future +using namespace mxnet; + +// macro to guard beginning and end section of all functions +// every function starts with API_BEGIN(); and finishes with API_END(); +#define API_BEGIN() try { +#define API_END() } catch(MXNetException &e) { return MXHandleException(e); } return 0; + +/*! 
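
To make the control flow concrete, this is roughly what an entry point looks like once `API_BEGIN()`/`API_END()` are expanded (illustrative only; `MXExampleCall` is a hypothetical name):

```cpp
int MXExampleCall(NArrayHandle handle) {
  try {                                    // API_BEGIN()
    static_cast<NArray*>(handle)->Wait();  // the actual function body
  } catch (MXNetException &e) {            // API_END()
    return MXHandleException(e);
  }
  return 0;
}
```
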
+ * \brief handle exception throwed out + * \param e the exception + * \return the return value of API after exception is handled + */ +int MXHandleException(const MXNetException &e) { + return -1; +} + +// NOTE: return value is added in API_END +int MXNArrayCreateNone(NArrayHandle *out) { + API_BEGIN(); + *out = new NArray(); + API_END(); +} + +int MXNArrayCreateShareMem(mx_float *data, + mx_uint *shape, + mx_uint ndim, + NArrayHandle *out) { + API_BEGIN(); + *out = new NArray(TBlob(data, TShape(shape, shape + ndim), + cpu::kDevMask), 0); + API_END(); +} + +int MXNArrayCreate(const mx_uint *shape, + mx_uint ndim, + int dev_mask, + int dev_id, + NArrayHandle *out) { + API_BEGIN(); + *out = new NArray(TShape(shape, shape + ndim), + Context(dev_mask, dev_id)); + API_END(); +} + +int MXNArrayWait(NArrayHandle handle) { + API_BEGIN(); + static_cast(handle)->Wait(); + API_END(); +} + +int MXNArrayWaitAll() { + API_BEGIN(); + DAGEngine::Get()->WaitForAll(); + API_END(); +} + +int MXNArrayFree(NArrayHandle handle) { + API_BEGIN(); + delete static_cast(handle); + API_END(); +} + +int MXNArrayGetShape(NArrayHandle handle, + mx_uint *out_dim, + const mx_uint **out_pdata) { + API_BEGIN(); + NArray *arr = static_cast(handle); + if (!arr->is_none()) { + const TShape &s = arr->shape(); + *out_dim = s.ndim(); + *out_pdata = s.data(); + } else { + *out_dim = 0; + } + API_END(); +} + +int MXNArrayGetData(NArrayHandle handle, + mx_float **out_pdata) { + API_BEGIN(); + NArray *arr = static_cast(handle); + if (!arr->is_none()) { + // TODO: change to exception + CHECK(arr->ctx().dev_mask != cpu::kDevMask); + const TBlob &b = arr->data(); + CHECK(b.CheckContiguous()); + *out_pdata = b.FlatTo2D().dptr_; + } else { + *out_pdata = nullptr; + } + API_END(); +} + +int MXNArrayGetDevice(NArrayHandle handle, + int *out_dev_mask, + int *out_dev_id) { + API_BEGIN(); + NArray *arr = static_cast(handle); + if (!arr->is_none()) { + const Context &ctx = arr->ctx(); + *out_dev_mask = ctx.dev_mask; + *out_dev_id = ctx.dev_id; + } else { + *out_dev_mask = 0; + *out_dev_id = 0; + } + API_END(); +} + +int MXListFunctions(mx_uint *out_size, + FunctionHandle **out_array) { + API_BEGIN(); + auto &vec = FunctionRegistry::List(); + *out_size = static_cast(vec.size()); + *out_array = (FunctionHandle*)(dmlc::BeginPtr(vec)); + API_END(); +} + +int MXGetFunction(const char *name, + FunctionHandle *out) { + API_BEGIN(); + *out = FunctionRegistry::Find(name); + API_END(); +} + +int MXFuncGetName(FunctionHandle fun, + const char **out_name) { + API_BEGIN(); + auto *f = static_cast(fun); + *out_name = f->name.c_str(); + API_END(); +} + +int MXFuncDescribeArgs(FunctionHandle fun, + mx_uint *num_use_vars, + mx_uint *num_scalars, + mx_uint *num_mutate_vars) { + API_BEGIN(); + auto *f = static_cast(fun); + *num_use_vars = f->num_use_vars; + *num_scalars = f->num_scalars; + *num_mutate_vars = f->num_mutate_vars; + API_END(); +} + +int MXFuncInvoke(FunctionHandle fun, + NArrayHandle *use_vars, + mx_float *scalar_args, + NArrayHandle *mutate_vars) { + API_BEGIN(); + auto *f = static_cast(fun); + (*f)((NArray**)(use_vars), + scalar_args, + (NArray**)(mutate_vars)); + API_END(); +} diff --git a/api/mxnet_api.h b/api/mxnet_api.h index 7c4cb18b9f5d..408781771dd7 100644 --- a/api/mxnet_api.h +++ b/api/mxnet_api.h @@ -26,7 +26,7 @@ typedef float mx_float; /*! \brief handle to NArray */ typedef void *NArrayHandle; /*! \brief handle to a mxnet narray function that changes NArray */ -typedef void *FunctionHandle; +typedef const void *FunctionHandle; /*! 
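
A hedged end-to-end sketch of the C API as it stands in this patch (five-argument `MXNArrayCreate`, capitalized function names): create two inputs, look up the registered "Plus" function, invoke it, and clean up. `Example` is a hypothetical name, and the CPU device mask value of 1 is an assumption based on the device mapping used elsewhere in the codebase.

```cpp
#include "mxnet_api.h"

int Example() {
  mx_uint shape[] = {3, 4};
  NArrayHandle a, b, out;
  if (MXNArrayCreate(shape, 2, /*dev_mask=*/1, /*dev_id=*/0, &a) != 0) return -1;
  if (MXNArrayCreate(shape, 2, /*dev_mask=*/1, /*dev_id=*/0, &b) != 0) return -1;
  if (MXNArrayCreateNone(&out) != 0) return -1;  // result allocated lazily

  FunctionHandle plus;
  if (MXGetFunction("Plus", &plus) != 0) return -1;
  NArrayHandle use_vars[] = {a, b};
  NArrayHandle mutate_vars[] = {out};
  if (MXFuncInvoke(plus, use_vars, /*scalar_args=*/nullptr, mutate_vars) != 0) return -1;

  MXNArrayWaitAll();  // block until the DAG engine has drained
  MXNArrayFree(a); MXNArrayFree(b); MXNArrayFree(out);
  return 0;
}
```
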
\brief handle to a symbol that can be bind as operator */ typedef void *SymbolHandle; /*! \brief handle to a NArrayOperator */ @@ -105,25 +105,27 @@ MXNET_DLL int MXNArrayFree(NArrayHandle handle); * \param out_pdata pointer holder to get data pointer of the shape * \return 0 when success, -1 when failure happens */ -MXNET_DLL int MXNArrayGetShape(NArrayHandle *handle, +MXNET_DLL int MXNArrayGetShape(NArrayHandle handle, mx_uint *out_dim, - mx_uint **out_pdata); + const mx_uint **out_pdata); /*! * \brief get the content of the data in NArray * \param handle the handle to the narray * \param out_pdata pointer holder to get pointer of data * \return 0 when success, -1 when failure happens */ -MXNET_DLL int MXNArrayGetData(NArrayHandle *handle, +MXNET_DLL int MXNArrayGetData(NArrayHandle handle, mx_float **out_pdata); /*! * \brief get the device of the NArray * \param handle the handle to the narray - * \param out_device the output device mask + * \param out_dev_mask the output device mask + * \param out_dev_id the output device id * \return 0 when success, -1 when failure happens */ -MXNET_DLL int MXNArrayGetDevice(NArrayHandle *handle, - int *out_device); +MXNET_DLL int MXNArrayGetDevice(NArrayHandle handle, + int *out_dev_mask, + int *out_dev_id); //-------------------------------- // Part 2: functions on NArray diff --git a/include/mxnet/base.h b/include/mxnet/base.h index 0bc636161ee1..13329518142e 100644 --- a/include/mxnet/base.h +++ b/include/mxnet/base.h @@ -5,6 +5,8 @@ */ #ifndef MXNET_BASE_H_ #define MXNET_BASE_H_ +#include +#include #include #include @@ -37,5 +39,11 @@ typedef mshadow::gpu gpu; typedef mshadow::index_t index_t; /*! \brief data type that will be used to store ndarray */ typedef mshadow::default_real_t real_t; + +/*! + * \brief exception throwed when the error is caused + */ +struct MXNetException : public std::exception { +}; } // namespace mxnet #endif // MXNET_BASE_H_ diff --git a/include/mxnet/narray.h b/include/mxnet/narray.h index 8e189a6ee960..f9132c6e31be 100644 --- a/include/mxnet/narray.h +++ b/include/mxnet/narray.h @@ -49,6 +49,12 @@ class NArray { inline const TShape &shape() const { return ptr_->data.shape_; } + /*! + * \return the data TBlob + */ + inline const TBlob &data() const { + return ptr_->data; + } /*! * \return the context of NArray, this function is only valid when the NArray is not empty */ @@ -59,6 +65,11 @@ class NArray { inline bool is_none() const { return ptr_.get() == nullptr; } + /*! \brief wait until the result of the NArray is computed */ + inline void Wait() const { + if (is_none()) return; + DAGEngine::Get()->WaitForVar(ptr_->var); + } /*! 
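
A small sketch (not in the patch) of the intended read-back discipline: operations are scheduled asynchronously on the DAG engine, so call `Wait()` before touching the underlying `TBlob` on the host; `ReadBack` is a hypothetical helper.

```cpp
void ReadBack(const mxnet::NArray &result) {
  result.Wait();  // block until every pending write to result completes
  const mxnet::TBlob &blob = result.data();
  // blob is now safe to read, assuming result lives in a CPU context
  (void)blob;
}
```
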
* \brief set all the elements in narray to be scalar * \param scalar the scalar to set From 82a7f3588049763573a7bc980cb0d8aa8eae9cd5 Mon Sep 17 00:00:00 2001 From: Bing Xu Date: Sat, 27 Jun 2015 23:03:53 -0600 Subject: [PATCH 04/11] initial version of api --- .gitignore | 5 ++ Makefile | 2 +- api/mxnet_api.cc | 31 +++++++- api/mxnet_api.h | 10 +++ api/python/mxnet/__init__.py | 11 +++ api/python/mxnet/base.py | 137 +++++++++++++++++++++++++++++++++++ api/python/mxnet/narray.py | 106 +++++++++++++++++++++++++++ api/python/test_python.py | 11 +++ include/mxnet/base.h | 7 -- src/narray/narray.cc | 8 +- test/api_registry_test.cc | 2 +- 11 files changed, 313 insertions(+), 17 deletions(-) create mode 100644 api/python/mxnet/__init__.py create mode 100644 api/python/mxnet/base.py create mode 100644 api/python/mxnet/narray.py create mode 100644 api/python/test_python.py diff --git a/.gitignore b/.gitignore index a63de96ac6d6..8b1cd8d0c906 100644 --- a/.gitignore +++ b/.gitignore @@ -30,3 +30,8 @@ dmlc-core mshadow config.mk +*.pyc +.Rhistory +*log +Debug +*suo diff --git a/Makefile b/Makefile index 60abe5c232ca..d67f71fa4b19 100644 --- a/Makefile +++ b/Makefile @@ -99,5 +99,5 @@ $(CUBIN) : $(NVCC) -o $@ $(NVCCFLAGS) -Xcompiler "$(CFLAGS)" -Xlinker "$(LDFLAGS)" $(filter %.cu %.cpp %.o, $^) clean: - $(RM) $(OBJ) $(OBJCXX11) $(BIN) $(CUBIN) $(CUOBJ) $(SLIB) $(ALIB) *~ */*~ */*/*~ + $(RM) $(OBJ) $(OBJCXX11) $(BIN) $(CUBIN) $(CUOBJ) $(SLIB) $(ALIB) *~ */*~ */*/*~ */*/*/*~ cd $(DMLC_CORE); make clean; cd - diff --git a/api/mxnet_api.cc b/api/mxnet_api.cc index 3c4e52dffacc..aeab45275f1a 100644 --- a/api/mxnet_api.cc +++ b/api/mxnet_api.cc @@ -1,3 +1,5 @@ +#include +#include #include #include #include @@ -11,14 +13,35 @@ using namespace mxnet; // macro to guard beginning and end section of all functions // every function starts with API_BEGIN(); and finishes with API_END(); #define API_BEGIN() try { -#define API_END() } catch(MXNetException &e) { return MXHandleException(e); } return 0; +#define API_END() } catch(dmlc::Error &e) { return MXHandleException(e); } return 0; + +/*! + * \brief a helper function for error handling + * will set the last error to be str_set when it is not NULL + * \param str_set the error to set + * \return a pointer message to last error + */ +const char *MXSetGetLastError_(const char *str_set) { + // use last_error to record last error + static thread_local std::string last_error; + if (str_set != NULL) { + last_error = str_set; + } + return last_error.c_str(); +} + +/*! \brief return str message of the last error */ +const char *MXGetLastError() { + return MXSetGetLastError_(NULL); +} /*! * \brief handle exception throwed out * \param e the exception * \return the return value of API after exception is handled */ -int MXHandleException(const MXNetException &e) { +int MXHandleException(const dmlc::Error &e) { + MXSetGetLastError_(e.what()); return -1; } @@ -88,8 +111,8 @@ int MXNArrayGetData(NArrayHandle handle, API_BEGIN(); NArray *arr = static_cast(handle); if (!arr->is_none()) { - // TODO: change to exception - CHECK(arr->ctx().dev_mask != cpu::kDevMask); + CHECK(arr->ctx().dev_mask == cpu::kDevMask) + << "MXNArrayGetData can only be called for NArray on CPU"; const TBlob &b = arr->data(); CHECK(b.CheckContiguous()); *out_pdata = b.FlatTo2D().dptr_; diff --git a/api/mxnet_api.h b/api/mxnet_api.h index 408781771dd7..800797fb5907 100644 --- a/api/mxnet_api.h +++ b/api/mxnet_api.h @@ -34,6 +34,16 @@ typedef void *OperatorHandle; /*! 
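
A caller-side sketch of the error protocol introduced in this patch (`Checked` is a hypothetical helper): every entry point returns 0 on success and -1 on failure, and the thread-local message is retrieved with `MXGetLastError()`, declared just below.

```cpp
#include <cstdio>

int Checked(int rc) {
  if (rc != 0) {
    std::fprintf(stderr, "mxnet error: %s\n", MXGetLastError());
  }
  return rc;
}

// usage: Checked(MXNArrayWaitAll());
```
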
\brief handle to a DataIterator */ typedef void *DataIterHandle; +/*! + * \brief return str message of the last error + * all function in this file will return 0 when success + * and -1 when an error occured, + * MXGetLastError can be called to retrieve the error + * + * this function is threadsafe and can be called by different thread + */ +MXNET_DLL const char *MXGetLastError(); + //-------------------------------- // Part 1: NArray creation and deletion //-------------------------------- diff --git a/api/python/mxnet/__init__.py b/api/python/mxnet/__init__.py new file mode 100644 index 000000000000..8fad96d85275 --- /dev/null +++ b/api/python/mxnet/__init__.py @@ -0,0 +1,11 @@ +#!/usr/bin/env python +# coding: utf-8 +"""MXNet: a concise, fast and flexible framework for deep learning + +MXNet is a project that evolves from cxxnet, minerva and purine2. +The interface is designed in collaboration by authors of three projects. + +Version : 0.10 +""" +from narray import NArray +from narray import zeros_shared diff --git a/api/python/mxnet/base.py b/api/python/mxnet/base.py new file mode 100644 index 000000000000..cdf920741185 --- /dev/null +++ b/api/python/mxnet/base.py @@ -0,0 +1,137 @@ +# coding: utf-8 +""" ctypes library of mxnet and helper functions """ +from __future__ import absolute_import + +import os +import sys +import ctypes +import platform +import numpy as np + +#---------------------------- +# library loading +#---------------------------- +if sys.version_info[0] == 3: + string_types = str, +else: + string_types = basestring, + +class MXNetError(Exception): + pass + + +def _load_lib(): + """load libary by looking at possible path""" + api_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) + api_path = os.path.join(api_path, '../../') + dll_path = [api_path] + if os.name == 'nt': + if platform.architecture()[0] == '64bit': + dll_path.append(os.path.join(api_path, '../windows/x64/Release/')) + else: + dll_path.append(os.path.join(api_path, '../windows/Release/')) + if os.name == 'nt': + dll_path = [os.path.join(p, 'mxnet.dll') for p in dll_path] + else: + dll_path = [os.path.join(p, 'libmxnet.so') for p in dll_path] + lib_path = [p for p in dll_path if os.path.exists(p) and os.path.isfile(p)] + if len(dll_path) == 0: + raise MXNetError('cannot find find the files in the candicate path ' + str(dll_path)) + lib = ctypes.cdll.LoadLibrary(lib_path[0]) + + # DMatrix functions + lib.MXGetLastError.restype = ctypes.c_char_p + return lib + + +lib = _load_lib() + +#---------------------------- +# helper function definition +#---------------------------- +def check_call(ret): + """check the return value of C API call + + this function will raise exception when error occurs + """ + if ret != 0: + raise MXNetError(lib.MXGetLastError()); + + +def new_narray_handle(): + """return a new empty handle + + Returns + ------- + a new empty narray handle + """ + h = ctypes.c_void_p() + check_call(lib.MXNArrayCreateNone(ctypes.byref(h))) + return h + +def c_array(ctype, values): + """get ctypes array + + Parameters + ---------- + ctype : ctypes data type + data type of the array we want to convert to + values : tuple like + data content + """ + return (ctype * len(values))(*values) + +def ctypes2numpy(cptr, shape): + """convert a ctypes pointer array to a numpy array. 
+ """ + if not isinstance(cptr, ctypes.POINTER(ctypes.c_float)): + raise RuntimeError('expected float pointer') + res = np.zeros(shape, dtype = np.float32) + if not ctypes.memmove(res.ctypes.data, cptr, res.size * res.strides[-1]): + raise RuntimeError('memmove failed') + return res + +#------------------------------ +# get list of functon pointers +#------------------------------ +class FunctionRegistry: + def __init__(self): + plist = ctypes.POINTER(ctypes.c_void_p)() + size = ctypes.c_uint() + check_call(lib.MXListFunctions(ctypes.byref(size), + ctypes.byref(plist))) + hmap = {} + print size.value + for i in range(size.value): + h = plist[i] + name = ctypes.c_char_p() + check_call(lib.MXFuncGetName(h, ctypes.byref(name))) + hmap[name.value] = h + self.__dict__.update(hmap) + +# handle to function registry +op = FunctionRegistry() + + +def invoke(fhandle, used_vars, scalars, mutate_vars): + """invoke a function handle by passing in arguments as tuples + + Parameters + ---------- + fhandle : ctypes.c_void_p + function handle of C API + + used_vars : tuple + tuple of NArray handles + + scalars : tuple + tuple of real number arguments + + mutate_vars : tuple + tuple of NArray handles to mutate + """ + check_call(lib.MXFuncInvoke( + fhandle, + c_array(ctypes.c_void_p, used_vars), + c_array(ctypes.c_float, scalars), + c_array(ctypes.c_void_p, mutate_vars))) diff --git a/api/python/mxnet/narray.py b/api/python/mxnet/narray.py new file mode 100644 index 000000000000..6a73f2d9f00f --- /dev/null +++ b/api/python/mxnet/narray.py @@ -0,0 +1,106 @@ +# coding: utf-8 +"""NArray interface of mxnet""" +import ctypes +import numpy as np +from base import lib +from base import op +from base import c_array +from base import ctypes2numpy +from base import invoke +from base import check_call +from base import new_narray_handle +from base import MXNetError + +# function handles +_h_plus = op.plus +_h_minus = op.minus +_h_mul = op.mul +_h_div = op.div + +class NArray(object): + """NArray object in mxnet + + NArray is basic ndarray like data structure in mxnet + """ + def __init__(self, handle): + """initialize a new NArray + + Parameters + ---------- + handle : ctypes.c_void_p + NArray handle of C API + """ + assert isinstance(handle, ctypes.c_void_p) + self.handle = handle + + def __del__(self): + check_call(lib.MXNArrayFree(self.handle)) + + def __lbinary__(self, handle, other): + if isinstance(other, NArray): + hret = new_narray_handle() + invoke(_h_plus, (self.handle, other.handle), (), (hret,)) + return NArray(handle = hret) + else: + raise MXNetError('type ' + str(other) + 'not supported') + + def __add__(self, other): + return self.__lbinary__(_h_plus, other) + + def __sub__(self, other): + return self.__lbinary__(_h_plus, other) + + def wait(self): + """Wait until the data on current NArray is available""" + check_call(lib.MXNArrayWait(self.handle)) + + def get_shape(self): + """Get shape of current NArray + + Returns + ------- + a tuple representing shape of current narray + """ + ndim = ctypes.c_uint() + pdata = ctypes.POINTER(ctypes.c_uint)() + check_call(lib.MXNArrayGetShape( + self.handle, ctypes.byref(ndim), ctypes.byref(pdata))) + return tuple(pdata[i] for i in range(ndim.value)) + + def to_numpy(self): + """Return a copy of numpy NArray + + Returns + ------- + a tuple representing shape of current narray + """ + self.wait() + pdata = ctypes.POINTER(ctypes.c_float)() + check_call(lib.MXNArrayGetData(self.handle, ctypes.byref(pdata))) + shape = self.get_shape() + print shape + return 
ctypes2numpy(pdata, shape) + + +def zeros_shared(shape): + """Create a new CPU based narray that shares memory content with a numpy array + + Parameters + ---------- + shape : tuple + shape of the NArray + + Returns + ------- + a new NArray that shares memory with numpy.narray + """ + h = ctypes.c_void_p() + data = np.zeros(shape, dtype = np.float32) + ndim = len(shape) + check_call(lib.MXNArrayCreateShareMem( + data.ctypes.data, + c_array(ctypes.c_uint, shape), + ndim, ctypes.byref(h))) + ret = NArray(handle = h) + ret.numpy = data + return ret diff --git a/api/python/test_python.py b/api/python/test_python.py new file mode 100644 index 000000000000..a1a9141b9ad7 --- /dev/null +++ b/api/python/test_python.py @@ -0,0 +1,11 @@ +import mxnet as mx + +a = mx.zeros_shared((3,4)) +b = mx.zeros_shared((3,4)) +a.numpy[:] = 10 +b.numpy[:] = 11 +print a.numpy +c = b + a + +print c.to_numpy() + diff --git a/include/mxnet/base.h b/include/mxnet/base.h index 13329518142e..f58a7f263f60 100644 --- a/include/mxnet/base.h +++ b/include/mxnet/base.h @@ -5,7 +5,6 @@ */ #ifndef MXNET_BASE_H_ #define MXNET_BASE_H_ -#include #include #include #include @@ -39,11 +38,5 @@ typedef mshadow::gpu gpu; typedef mshadow::index_t index_t; /*! \brief data type that will be used to store ndarray */ typedef mshadow::default_real_t real_t; - -/*! - * \brief exception throwed when the error is caused - */ -struct MXNetException : public std::exception { -}; } // namespace mxnet #endif // MXNET_BASE_H_ diff --git a/src/narray/narray.cc b/src/narray/narray.cc index c4b0553f937d..598f30e63205 100644 --- a/src/narray/narray.cc +++ b/src/narray/narray.cc @@ -86,8 +86,8 @@ NArray &NArray::operator/=(const NArray &src) { } // register API function -REGISTER_NARRAY_FUN(Plus).set_function(BinaryEWise); -REGISTER_NARRAY_FUN(Minus).set_function(BinaryEWise); -REGISTER_NARRAY_FUN(Mul).set_function(BinaryEWise); -REGISTER_NARRAY_FUN(Div).set_function(BinaryEWise); +REGISTER_NARRAY_FUN(plus).set_function(BinaryEWise); +REGISTER_NARRAY_FUN(minus).set_function(BinaryEWise); +REGISTER_NARRAY_FUN(mul).set_function(BinaryEWise); +REGISTER_NARRAY_FUN(div).set_function(BinaryEWise); } // namespace mxnet diff --git a/test/api_registry_test.cc b/test/api_registry_test.cc index dd6604c6cdb9..8e82fad7dc56 100644 --- a/test/api_registry_test.cc +++ b/test/api_registry_test.cc @@ -4,6 +4,6 @@ int main(int argc, char *argv[]) { auto fadd = mxnet::FunctionRegistry::Find("Plus"); - printf("f.name=%s\n", fadd->name.c_str()); + printf("f.name=%s\n", fadd->name.c_str()); return 0; } From 0e30f9b5d4fff3514473bbbff0291983d5131af4 Mon Sep 17 00:00:00 2001 From: Bing Xu Date: Sat, 27 Jun 2015 23:06:19 -0600 Subject: [PATCH 05/11] remove debug --- api/python/mxnet/base.py | 1 - api/python/mxnet/narray.py | 1 - src/api_registry.cc | 2 -- 3 files changed, 4 deletions(-) diff --git a/api/python/mxnet/base.py b/api/python/mxnet/base.py index cdf920741185..6aaac5896fd5 100644 --- a/api/python/mxnet/base.py +++ b/api/python/mxnet/base.py @@ -101,7 +101,6 @@ def __init__(self): check_call(lib.MXListFunctions(ctypes.byref(size), ctypes.byref(plist))) hmap = {} - print size.value for i in range(size.value): h = plist[i] name = ctypes.c_char_p() diff --git a/api/python/mxnet/narray.py b/api/python/mxnet/narray.py index 6a73f2d9f00f..9007752ba6d8 100644 --- a/api/python/mxnet/narray.py +++ b/api/python/mxnet/narray.py @@ -78,7 +78,6 @@ def to_numpy(self): pdata = ctypes.POINTER(ctypes.c_float)() check_call(lib.MXNArrayGetData(self.handle, ctypes.byref(pdata))) shape = 
self.get_shape() - print shape return ctypes2numpy(pdata, shape) diff --git a/src/api_registry.cc b/src/api_registry.cc index 151ad9c7ad37..93e0bd0ee678 100644 --- a/src/api_registry.cc +++ b/src/api_registry.cc @@ -10,8 +10,6 @@ FunctionRegistry::Register(const std::string name) { Entry *e = new Entry(name); fmap_[name] = e; fun_list_.push_back(e); - // delete me later - LOG(INFO) << "register function " << name; return *e; } From 9947f970c8a6eddc9abef70d994708f55df5ab0b Mon Sep 17 00:00:00 2001 From: Bing Xu Date: Sat, 27 Jun 2015 23:12:19 -0600 Subject: [PATCH 06/11] python3 compatibility --- api/python/mxnet/__init__.py | 5 +++-- api/python/mxnet/narray.py | 18 ++++++++++-------- api/python/test_python.py | 4 ++-- 3 files changed, 15 insertions(+), 12 deletions(-) diff --git a/api/python/mxnet/__init__.py b/api/python/mxnet/__init__.py index 8fad96d85275..395114679af9 100644 --- a/api/python/mxnet/__init__.py +++ b/api/python/mxnet/__init__.py @@ -7,5 +7,6 @@ Version : 0.10 """ -from narray import NArray -from narray import zeros_shared +from __future__ import absolute_import +from .narray import NArray +from .narray import zeros_shared diff --git a/api/python/mxnet/narray.py b/api/python/mxnet/narray.py index 9007752ba6d8..83ff6ce2599b 100644 --- a/api/python/mxnet/narray.py +++ b/api/python/mxnet/narray.py @@ -1,15 +1,17 @@ # coding: utf-8 """NArray interface of mxnet""" +from __future__ import absolute_import + import ctypes import numpy as np -from base import lib -from base import op -from base import c_array -from base import ctypes2numpy -from base import invoke -from base import check_call -from base import new_narray_handle -from base import MXNetError +from .base import lib +from .base import op +from .base import c_array +from .base import ctypes2numpy +from .base import invoke +from .base import check_call +from .base import new_narray_handle +from .base import MXNetError # function handles _h_plus = op.plus diff --git a/api/python/test_python.py b/api/python/test_python.py index a1a9141b9ad7..54c6a5c92bc2 100644 --- a/api/python/test_python.py +++ b/api/python/test_python.py @@ -4,8 +4,8 @@ b = mx.zeros_shared((3,4)) a.numpy[:] = 10 b.numpy[:] = 11 -print a.numpy +print(a.numpy) c = b + a -print c.to_numpy() +print(c.to_numpy()) From ca89d4c16a1d538971bd14885ab513270ca0b962 Mon Sep 17 00:00:00 2001 From: Bing Xu Date: Sun, 28 Jun 2015 10:59:05 -0600 Subject: [PATCH 07/11] clean --- api/python/mxnet/base.py | 78 +++++++++++++++++++++++------------- api/python/mxnet/narray.py | 82 ++++++++++++++++++++++++++------------ api/python/test_python.py | 2 +- 3 files changed, 107 insertions(+), 55 deletions(-) diff --git a/api/python/mxnet/base.py b/api/python/mxnet/base.py index 6aaac5896fd5..1be8ed274ec6 100644 --- a/api/python/mxnet/base.py +++ b/api/python/mxnet/base.py @@ -16,15 +16,17 @@ else: string_types = basestring, + class MXNetError(Exception): + """Error that will be throwed by all mxnet functions""" pass def _load_lib(): - """load libary by looking at possible path""" - api_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) - api_path = os.path.join(api_path, '../../') - dll_path = [api_path] + """load libary by searching possible path""" + curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) + api_path = os.path.join(curr_path, '../../') + dll_path = [api_path, curr_path] if os.name == 'nt': if platform.architecture()[0] == '64bit': dll_path.append(os.path.join(api_path, '../windows/x64/Release/')) @@ -44,47 +46,67 @@ def 
_load_lib(): return lib +# library instance of mxnet lib = _load_lib() +# type definitions +mx_uint = ctypes.c_uint +mx_float = ctypes.c_float +NArrayHandle = ctypes.c_void_p +FunctionHandle = ctypes.c_void_p + #---------------------------- # helper function definition #---------------------------- + def check_call(ret): - """check the return value of C API call + """Check the return value of C API call - this function will raise exception when error occurs + This function will raise exception when error occurs. + Wrap every API call with this function + + Parameters + ---------- + ret : int + return value from API calls """ if ret != 0: raise MXNetError(lib.MXGetLastError()); -def new_narray_handle(): - """return a new empty handle - - Returns - ------- - a new empty narray handle - """ - h = ctypes.c_void_p() - check_call(lib.MXNArrayCreateNone(ctypes.byref(h))) - return h - def c_array(ctype, values): - """get ctypes array + """Create ctypes array from a python array Parameters ---------- ctype : ctypes data type data type of the array we want to convert to - values : tuple like + values : tuple or list data content + + Returns + ------- + created ctypes array """ return (ctype * len(values))(*values) + def ctypes2numpy(cptr, shape): - """convert a ctypes pointer array to a numpy array. + """Convert a ctypes pointer to a numpy array. + + Parameters + ---------- + cptr : ctypes.POINTER(mx_float) + pointer to the memory region + + shape : tuple + shape of target narray + + Returns + ------- + a copy of nupy array : numpy array """ - if not isinstance(cptr, ctypes.POINTER(ctypes.c_float)): + if not isinstance(cptr, ctypes.POINTER(mx_float)): raise RuntimeError('expected float pointer') res = np.zeros(shape, dtype = np.float32) if not ctypes.memmove(res.ctypes.data, cptr, res.size * res.strides[-1]): @@ -94,7 +116,7 @@ def ctypes2numpy(cptr, shape): #------------------------------ # get list of functon pointers #------------------------------ -class FunctionRegistry: +class _FunctionRegistry: def __init__(self): plist = ctypes.POINTER(ctypes.c_void_p)() size = ctypes.c_uint() @@ -109,15 +131,15 @@ def __init__(self): self.__dict__.update(hmap) # handle to function registry -op = FunctionRegistry() +op = _FunctionRegistry() def invoke(fhandle, used_vars, scalars, mutate_vars): - """invoke a function handle by passing in arguments as tuples + """Invoke a function handle by passing in arguments as tuples Parameters ---------- - fhandle : ctypes.c_void_p + fhandle : FunctionHandle function handle of C API used_vars : tuple @@ -131,6 +153,6 @@ def invoke(fhandle, used_vars, scalars, mutate_vars): """ check_call(lib.MXFuncInvoke( fhandle, - c_array(ctypes.c_void_p, used_vars), - c_array(ctypes.c_float, scalars), - c_array(ctypes.c_void_p, mutate_vars))) + c_array(NArrayHandle, used_vars), + c_array(mx_float, scalars), + c_array(NArrayHandle, mutate_vars))) diff --git a/api/python/mxnet/narray.py b/api/python/mxnet/narray.py index 83ff6ce2599b..f76d74c1b8ac 100644 --- a/api/python/mxnet/narray.py +++ b/api/python/mxnet/narray.py @@ -7,17 +7,23 @@ from .base import lib from .base import op from .base import c_array +from .base import mx_uint, mx_float, NArrayHandle from .base import ctypes2numpy from .base import invoke from .base import check_call -from .base import new_narray_handle from .base import MXNetError -# function handles -_h_plus = op.plus -_h_minus = op.minus -_h_mul = op.mul -_h_div = op.div +def _new_empty_handle(): + """Return a new empty handle + + Empty handle is only used to hold 
results + Returns + ------- + a new empty narray handle + """ + h = NArrayHandle() + check_call(lib.MXNArrayCreateNone(ctypes.byref(h))) + return h class NArray(object): """NArray object in mxnet @@ -29,46 +35,71 @@ def __init__(self, handle): Parameters ---------- - handle : ctypes.c_void_p + handle : NArrayHandle NArray handle of C API """ - assert isinstance(handle, ctypes.c_void_p) + assert isinstance(handle, NArrayHandle) self.handle = handle def __del__(self): check_call(lib.MXNArrayFree(self.handle)) - def __lbinary__(self, handle, other): + def __add__(self, other): + hret = _new_empty_handle() if isinstance(other, NArray): - hret = new_narray_handle() - invoke(_h_plus, (self.handle, other.handle), (), (hret,)) - return NArray(handle = hret) + invoke(op.plus, (other.handle, self.handle), (), (hret,)) else: - raise MXNetError('type ' + str(other) + 'not supported') + raise MXNetError('type %s not supported' % str(type(other))) + return NArray(handle = hret) - def __add__(self, other): - return self.__lbinary__(_h_plus, other) + def __radd__(self, other): + return self.__add__(other) def __sub__(self, other): - return self.__lbinary__(_h_plus, other) + hret = _new_empty_handle() + if isinstance(other, NArray): + invoke(op.minus, (other.handle, self.handle), (), (hret,)) + else: + raise MXNetError('type %s not supported' % str(type(other))) + return NArray(handle = hret) + + def __mul__(self, other): + hret = _new_empty_handle() + if isinstance(other, NArray): + invoke(op.mul, (other.handle, self.handle), (), (hret,)) + else: + raise MXNetError('type %s not supported' % str(type(other))) + return NArray(handle = hret) + + def __rmul__(self, other): + return self.__mul__(other) + + def __div__(self, other): + hret = _new_empty_handle() + if isinstance(other, NArray): + invoke(op.div, (other.handle, self.handle), (), (hret,)) + else: + raise MXNetError('type %s not supported' % str(type(other))) + return NArray(handle = hret) def wait(self): """Wait until the data on current NArray is available""" check_call(lib.MXNArrayWait(self.handle)) - def get_shape(self): + @property + def shape(self): """Get shape of current NArray Returns ------- a tuple representing shape of current narray """ - ndim = ctypes.c_uint() - pdata = ctypes.POINTER(ctypes.c_uint)() + ndim = mx_uint() + pdata = ctypes.POINTER(mx_uint)() check_call(lib.MXNArrayGetShape( self.handle, ctypes.byref(ndim), ctypes.byref(pdata))) return tuple(pdata[i] for i in range(ndim.value)) - + def to_numpy(self): """Return a copy of numpy NArray @@ -77,10 +108,9 @@ def to_numpy(self): a tuple representing shape of current narray """ self.wait() - pdata = ctypes.POINTER(ctypes.c_float)() - check_call(lib.MXNArrayGetData(self.handle, ctypes.byref(pdata))) - shape = self.get_shape() - return ctypes2numpy(pdata, shape) + pdata = ctypes.POINTER(mx_float)() + check_call(lib.MXNArrayGetData(self.handle, ctypes.byref(pdata))) + return ctypes2numpy(pdata, self.shape) def zeros_shared(shape): @@ -95,12 +125,12 @@ def zeros_shared(shape): ------- a new NArray that shares memory with numpy.narray """ - h = ctypes.c_void_p() + h = NArrayHandle() data = np.zeros(shape, dtype = np.float32) ndim = len(shape) check_call(lib.MXNArrayCreateShareMem( data.ctypes.data, - c_array(ctypes.c_uint, shape), + c_array(mx_uint, shape), ndim, ctypes.byref(h))) ret = NArray(handle = h) ret.numpy = data diff --git a/api/python/test_python.py b/api/python/test_python.py index 54c6a5c92bc2..5d6385652b9a 100644 --- a/api/python/test_python.py +++ 
b/api/python/test_python.py @@ -5,7 +5,7 @@ a.numpy[:] = 10 b.numpy[:] = 11 print(a.numpy) -c = b + a +c = b / a print(c.to_numpy()) From 8424780a2b36db74f4980868637ddb4ff8a95a72 Mon Sep 17 00:00:00 2001 From: Bing Xu Date: Sun, 28 Jun 2015 14:36:12 -0600 Subject: [PATCH 08/11] make narray working --- api/mxnet_api.cc | 10 +-- api/mxnet_api.h | 11 +-- api/python/mxnet/__init__.py | 2 +- api/python/mxnet/base.py | 34 +++++++-- api/python/mxnet/context.py | 50 +++++++++++++ api/python/mxnet/narray.py | 97 +++++++++++++++++++------ api/python/test_python.py | 24 +++++-- include/mxnet/api_registry.h | 20 +++++- include/mxnet/dag_engine.h | 5 +- include/mxnet/narray.h | 48 +++++++++---- src/dag_engine/simple_engine.cc | 19 ++++- src/narray/narray.cc | 123 +++++++++++++++++++++++--------- src/narray/narray_op-inl.h | 9 +-- src/narray/narray_op.h | 8 ++- src/narray/narray_op_cpu.cc | 12 ++++ src/narray/narray_op_gpu.cu | 45 ++++++++++++ src/storage/storage.cc | 20 ++++-- 17 files changed, 433 insertions(+), 104 deletions(-) create mode 100644 api/python/mxnet/context.py diff --git a/api/mxnet_api.cc b/api/mxnet_api.cc index aeab45275f1a..3fefa4473570 100644 --- a/api/mxnet_api.cc +++ b/api/mxnet_api.cc @@ -66,10 +66,12 @@ int MXNArrayCreate(const mx_uint *shape, mx_uint ndim, int dev_mask, int dev_id, + int delay_alloc, NArrayHandle *out) { API_BEGIN(); *out = new NArray(TShape(shape, shape + ndim), - Context(dev_mask, dev_id)); + Context(dev_mask, dev_id), + delay_alloc != 0); API_END(); } @@ -122,9 +124,9 @@ int MXNArrayGetData(NArrayHandle handle, API_END(); } -int MXNArrayGetDevice(NArrayHandle handle, - int *out_dev_mask, - int *out_dev_id) { +int MXNArrayGetContext(NArrayHandle handle, + int *out_dev_mask, + int *out_dev_id) { API_BEGIN(); NArray *arr = static_cast(handle); if (!arr->is_none()) { diff --git a/api/mxnet_api.h b/api/mxnet_api.h index 800797fb5907..517de39f180a 100644 --- a/api/mxnet_api.h +++ b/api/mxnet_api.h @@ -81,6 +81,8 @@ MXNET_DLL int MXNArrayCreateShareMem(mx_float *data, * \param ndim the dimension of the shape * \param dev_mask device mask, specify device we want to take * \param dev_id the device id of the specific device + * \param delay_alloc whether to delay allocation until + * the narray is first mutated * \param out the returning handle * \return 0 when success, -1 when failure happens */ @@ -88,6 +90,7 @@ MXNET_DLL int MXNArrayCreate(const mx_uint *shape, mx_uint ndim, int dev_mask, int dev_id, + int delay_alloc, NArrayHandle *out); /*! * \brief wait until all the operation with respect NArray @@ -127,15 +130,15 @@ MXNET_DLL int MXNArrayGetShape(NArrayHandle handle, MXNET_DLL int MXNArrayGetData(NArrayHandle handle, mx_float **out_pdata); /*! 
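
A short sketch of the new `delay_alloc` flag (`MakeGpuTarget` is hypothetical; the GPU device mask value of 2 follows the mapping in `context.py`): a handle created with delayed allocation acquires memory only when the engine first mutates it, which is exactly what the Python `copyto` path relies on.

```cpp
int MakeGpuTarget(NArrayHandle *dst) {
  mx_uint shape[] = {3000, 4000};
  // no device memory is allocated here; allocation happens when *dst is
  // first written, e.g. as the destination of a copy
  return MXNArrayCreate(shape, 2, /*dev_mask=*/2, /*dev_id=*/0,
                        /*delay_alloc=*/1, dst);
}
```
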
- * \brief get the device of the NArray + * \brief get the context of the NArray * \param handle the handle to the narray * \param out_dev_mask the output device mask * \param out_dev_id the output device id * \return 0 when success, -1 when failure happens */ -MXNET_DLL int MXNArrayGetDevice(NArrayHandle handle, - int *out_dev_mask, - int *out_dev_id); +MXNET_DLL int MXNArrayGetContext(NArrayHandle handle, + int *out_dev_mask, + int *out_dev_id); //-------------------------------- // Part 2: functions on NArray diff --git a/api/python/mxnet/__init__.py b/api/python/mxnet/__init__.py index 395114679af9..36f7732167a4 100644 --- a/api/python/mxnet/__init__.py +++ b/api/python/mxnet/__init__.py @@ -9,4 +9,4 @@ """ from __future__ import absolute_import from .narray import NArray -from .narray import zeros_shared +from .context import Context, current_context diff --git a/api/python/mxnet/base.py b/api/python/mxnet/base.py index 1be8ed274ec6..581778ce6fb8 100644 --- a/api/python/mxnet/base.py +++ b/api/python/mxnet/base.py @@ -74,6 +74,22 @@ def check_call(ret): raise MXNetError(lib.MXGetLastError()); +def c_str(string): + """Create ctypes char * from a python string + + Parameters + ---------- + string : string type + python string + + Returns + ------- + a char pointer that can be passed to C API + """ + + return ctypes.c_char_p(string.encode('utf-8')) + + def c_array(ctype, values): """Create ctypes array from a python array @@ -81,6 +97,7 @@ def c_array(ctype, values): ---------- ctype : ctypes data type data type of the array we want to convert to + values : tuple or list data content @@ -91,8 +108,10 @@ def c_array(ctype, values): return (ctype * len(values))(*values) -def ctypes2numpy(cptr, shape): - """Convert a ctypes pointer to a numpy array. 
+def ctypes2numpy_shared(cptr, shape): + """Convert a ctypes pointer to a numpy array + + The result numpy array shares the memory with the pointer Parameters ---------- @@ -104,14 +123,15 @@ def ctypes2numpy(cptr, shape): Returns ------- - a copy of nupy array : numpy array + a numpy array : numpy array """ if not isinstance(cptr, ctypes.POINTER(mx_float)): raise RuntimeError('expected float pointer') - res = np.zeros(shape, dtype = np.float32) - if not ctypes.memmove(res.ctypes.data, cptr, res.size * res.strides[-1]): - raise RuntimeError('memmove failed') - return res + size = 1 + for s in shape: + size *= s + dbuffer = (mx_float * size).from_address(ctypes.addressof(cptr.contents)) + return np.frombuffer(dbuffer, dtype = np.float32).reshape(shape) #------------------------------ # get list of functon pointers diff --git a/api/python/mxnet/context.py b/api/python/mxnet/context.py new file mode 100644 index 000000000000..a440e310ce30 --- /dev/null +++ b/api/python/mxnet/context.py @@ -0,0 +1,50 @@ +# coding: utf-8 +""" code for context management """ +from __future__ import absolute_import + +class Context: + """Context representing device and device id in mxnet""" + # static class variable + default_ctx = None + devmask2type = { 1: 'cpu', 2: 'gpu'} + devtype2mask = {'cpu': 1, 'gpu': 2 } + + def __init__(self, device_type, device_id = 0): + """Constructing a context + + Parameters + ---------- + device_type : str (can be 'cpu' or 'gpu') + a string representing the device type + + device_id : int (default=0) + the device id of the device, needed for GPU + """ + self.device_mask = Context.devtype2mask[device_type] + self.device_id = device_id + + @property + def device_type(self): + return Context.devmask2type[self.device_mask] + + def __str__(self): + return 'Context(device_type=%s, device_id=%d)' % ( + self.device_type, self.device_id) + + def __repr__(self): + return self.__str__() + + def __enter__(self): + self._old_ctx = Context.default_ctx + Context.default_ctx = self + return self + + def __exit__(self, type, value, trace): + Context.default_ctx= self._old_ctx + +# initialize the default context in Context +Context.default_ctx = Context('cpu', 0) + +def current_context(): + """Return the current context""" + return Context.default_ctx diff --git a/api/python/mxnet/narray.py b/api/python/mxnet/narray.py index f76d74c1b8ac..eca23f2ad3c1 100644 --- a/api/python/mxnet/narray.py +++ b/api/python/mxnet/narray.py @@ -8,15 +8,17 @@ from .base import op from .base import c_array from .base import mx_uint, mx_float, NArrayHandle -from .base import ctypes2numpy +from .base import ctypes2numpy_shared from .base import invoke from .base import check_call from .base import MXNetError +from .context import Context def _new_empty_handle(): """Return a new empty handle + + Empty handle can be used to hold result - Empty handle is only used to hold results Returns ------- a new empty narray handle @@ -25,6 +27,24 @@ def _new_empty_handle(): check_call(lib.MXNArrayCreateNone(ctypes.byref(h))) return h +def _new_alloc_handle(shape, ctx, delay_alloc): + """Return a new handle with specified shape, context + + Empty handle is only used to hold results + Returns + ------- + a new empty narray handle + """ + h = NArrayHandle() + check_call(lib.MXNArrayCreate( + c_array(mx_uint, shape), + len(shape), + ctx.device_mask, + ctx.device_id, + int(delay_alloc), + ctypes.byref(h))) + return h + class NArray(object): """NArray object in mxnet @@ -100,21 +120,64 @@ def shape(self): self.handle, ctypes.byref(ndim), 
ctypes.byref(pdata))) return tuple(pdata[i] for i in range(ndim.value)) - def to_numpy(self): - """Return a copy of numpy NArray + @property + def context(self): + """Get context of current NArray Returns ------- - a tuple representing shape of current narray + the context of current NArray + """ + dev_mask = ctypes.c_int() + dev_id = ctypes.c_int() + check_call(lib.MXNArrayGetContext( + self.handle, ctypes.byref(dev_mask), ctypes.byref(dev_id))) + return Context(Context.devmask2type[dev_mask.value], dev_id.value) + + @property + def numpy(self): + """Return a numpy representation of current array + + This array have to sit on CPU + + Returns + ------- + a numpy array view """ self.wait() pdata = ctypes.POINTER(mx_float)() - check_call(lib.MXNArrayGetData(self.handle, ctypes.byref(pdata))) - return ctypes2numpy(pdata, self.shape) - + check_call(lib.MXNArrayGetData(self.handle, ctypes.byref(pdata))) + return ctypes2numpy_shared(pdata, self.shape) + + def copyto(self, other): + """copy the content of current array to othe + + When other is NArray, the content is copied over. + When other is a Context, a new NArray in the context + will be created as target + + Parameters + ---------- + other : NArray or Context + another narray we want to copy to, + or target context we want copy the data to + + Returns + ------- + the copy target NArray + """ + if isinstance(other, NArray): + invoke(op.copy, (self.handle,), (), (other.handle,)) + return other + elif isinstance(other, Context): + hret = _new_alloc_handle(self.shape, other, True) + invoke(op.copy, (self.handle,), (), (hret,)) + return NArray(handle = hret) + else: + raise MXNetError('copyto do not support type ' + type(other)) -def zeros_shared(shape): - """Create a new CPU based narray that shares memory content with a numpy array +def create(shape, ctx = Context.default_ctx): + """Create a new NArray, with specified shape Parameters ---------- @@ -123,15 +186,7 @@ def zeros_shared(shape): Returns ------- - a new NArray that shares memory with numpy.narray + a new NArray """ - h = NArrayHandle() - data = np.zeros(shape, dtype = np.float32) - ndim = len(shape) - check_call(lib.MXNArrayCreateShareMem( - data.ctypes.data, - c_array(mx_uint, shape), - ndim, ctypes.byref(h))) - ret = NArray(handle = h) - ret.numpy = data - return ret + return NArray(handle = _new_alloc_handle(shape, ctx, False)) + diff --git a/api/python/test_python.py b/api/python/test_python.py index 5d6385652b9a..f9e8003c6e06 100644 --- a/api/python/test_python.py +++ b/api/python/test_python.py @@ -1,11 +1,27 @@ import mxnet as mx -a = mx.zeros_shared((3,4)) -b = mx.zeros_shared((3,4)) +a = mx.narray.create((3000,4000)) +b = mx.narray.create((3000,4000)) a.numpy[:] = 10 b.numpy[:] = 11 print(a.numpy) -c = b / a -print(c.to_numpy()) +c = b * a + +print(c.context) + +d = c.copyto(mx.Context('cpu', 0)) + +print(d.numpy) + +with mx.Context('gpu', 0) as ctx: + # gpu operations + print mx.current_context() + print ctx + a_gpu = a.copyto(ctx) + b_gpu = b.copyto(ctx) + c_gpu = b * a + +d_cpu = c_gpu.copyto(mx.current_context()) +print d_cpu.numpy diff --git a/include/mxnet/api_registry.h b/include/mxnet/api_registry.h index c6cbdb6db89b..253eef1574a1 100644 --- a/include/mxnet/api_registry.h +++ b/include/mxnet/api_registry.h @@ -87,12 +87,28 @@ class FunctionRegistry { inline Entry &set_function(void fbinary(const NArray &lhs, const NArray &rhs, NArray *out)) { - body = [fbinary] (NArray **used_vars, real_t *s, NArray **mutate_vars) { + body = [fbinary] (NArray **used_vars, + real_t 
*s, NArray **mutate_vars) { fbinary(*used_vars[0], *used_vars[1], mutate_vars[0]); }; num_use_vars = 2; num_mutate_vars = 1; return *this; - } + } + /*! + * \brief set the function body to a unary NArray function + * this will also auto set the parameters correctly + * \param unary function body to set + * \return ref to the registered entry, used to set properties + */ + inline Entry &set_function(void funary(const NArray &src, + NArray *out)) { + body = [funary] (NArray **used_vars, + real_t *s, NArray **mutate_vars) { + funary(*used_vars[0], mutate_vars[0]); + }; + num_use_vars = 1; num_mutate_vars = 1; + return *this; + } /*! * \brief invoke the function * \param use_vars variables used by the function diff --git a/include/mxnet/dag_engine.h b/include/mxnet/dag_engine.h index 2dd8682aadfc..0f2ed61b71bc 100644 --- a/include/mxnet/dag_engine.h +++ b/include/mxnet/dag_engine.h @@ -78,9 +78,12 @@ class DAGEngine { * depending on var is completed * * \param delete_fun a function that will be called after var is deleted + * \param exec_ctx execution context * \param var the variable to be deleted */ - virtual void PushDelete(Op delete_fun, Variable var) = 0; + virtual void PushDelete(Op delete_fun, + Context exec_ctx, + Variable var) = 0; /*! * \brief allocate a new variable, the variable can then * be used to schedul the operation concurrently via dependency patterns diff --git a/include/mxnet/narray.h b/include/mxnet/narray.h index f9132c6e31be..177018c9fcaf 100644 --- a/include/mxnet/narray.h +++ b/include/mxnet/narray.h @@ -29,9 +29,11 @@ class NArray { * \brief constructing a new dynamic NArray * \param shape the shape of array * \param ctx context of NArray + * \param delay_alloc whether delay the allocation */ - NArray(const TShape &shape, Context ctx) - : ptr_(new Chunk(shape, ctx, false)) { + NArray(const TShape &shape, Context ctx, + bool delay_alloc = false) + : ptr_(new Chunk(shape, ctx, delay_alloc)) { } /*! * \brief constructing a static NArray that shares data with TBlob @@ -104,6 +106,17 @@ class NArray { * \return reference of self */ NArray &operator/=(const NArray &src); + /*! + * \brief return transpose of current NArray + * \return a new transposed NArray + */ + NArray T() const; + /*! + * \brief return a new copy this NArray + * \param ctx the new context of this NArray + * \return the new copy + */ + NArray Copy(Context ctx) const; private: /*! \brief the real data chunk that backs NArray */ @@ -152,31 +165,38 @@ class NArray { /*! \brief destructor */ ~Chunk() { if (static_data) { - DAGEngine::Get()->PushDelete([](RunContext s) {}, var); + DAGEngine::Get()->PushDelete([](RunContext s) {}, shandle.ctx, var); } else { CHECK(!delay_alloc) << "deleted before allocation"; StorageManager::Handle h = this->shandle; DAGEngine::Get()->PushDelete([h](RunContext s) { StorageManager::Get()->Free(h); - }, var); + }, shandle.ctx, var); } } }; /*! \brief internal data of NArray */ std::shared_ptr ptr_; - /*! 
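
A hedged sketch of the new `Copy` API from the C++ side (`CopyDemo` is hypothetical; assumes a CUDA-enabled build, and the shape and device ids are illustrative):

```cpp
void CopyDemo() {
  mxnet::index_t dims[] = {3, 4};
  mxnet::TShape shape(dims, dims + 2);
  mxnet::NArray a(shape, mxnet::Context(mxnet::cpu::kDevMask, 0));
  // Copy allocates a same-shape target in the given context and schedules
  // an asynchronous cross-device CopyFromTo
  mxnet::NArray b = a.Copy(mxnet::Context(mxnet::gpu::kDevMask, 0));
  b.Wait();  // block until the copy has completed
}
```
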
- * \brief constructing a new dynamic NArray - * \param shape the shape of array - * \param ctx context of NArray - * \param delay_alloc whether delay the allocation - */ - NArray(const TShape &shape, Context ctx, bool delay_alloc) - : ptr_(new Chunk(shape, ctx, delay_alloc)) { - } // add friend to helper functions + friend void CopyFromTo(const NArray &from, NArray *to); + template + friend void BinaryOp(const NArray &lhs, const NArray &rhs, NArray *out); template - friend void BinaryEWise(const NArray &lhs, const NArray &rhs, NArray *out); + friend void UnaryOp(const NArray &lhs, const NArray &rhs, NArray *out); }; + +/*! + * \brief issue an copy operation from one NArray to another + * the two narray can sit on different devices + * this operation will be scheduled by the engine + * + * NOTE: this function name explicitly marks the order of from and to + * due to different possible convention carried by copy function + * \param from the narray we want to copy data from + * \param to the target narray + */ +void CopyFromTo(const NArray &from, NArray *to); + /*! * \brief elementwise add * \param lhs left operand diff --git a/src/dag_engine/simple_engine.cc b/src/dag_engine/simple_engine.cc index 9ea42e979735..d38a2daba63a 100644 --- a/src/dag_engine/simple_engine.cc +++ b/src/dag_engine/simple_engine.cc @@ -3,6 +3,7 @@ namespace mxnet { class SimpleEngine : public DAGEngine { public: + virtual void Push(AsyncOp exec_fun, Context exec_ctx, const std::vector &use_vars, @@ -14,10 +15,18 @@ class SimpleEngine : public DAGEngine { Context exec_ctx, const std::vector &use_vars, const std::vector &mutate_vars) { - exec_fun(RunContext()); + if (exec_ctx.dev_mask == gpu::kDevMask) { + ctx_.stream = &stream; + mshadow::SetDevice(exec_ctx.dev_id); + exec_fun(ctx_); + } else { + exec_fun(ctx_); + } } - virtual void PushDelete(Op delete_fun, Variable var) { - delete_fun(RunContext()); + virtual void PushDelete(Op delete_fun, + Context exec_ctx, + Variable var) { + this->Push(delete_fun, exec_ctx, {}, {var}); } virtual Variable NewVar() { // in practice return a ptr to a cell @@ -25,6 +34,10 @@ class SimpleEngine : public DAGEngine { // use ptr directly instead of ID because this avoids an indirect mapping return NULL; } + + private: + RunContext ctx_; + mshadow::Stream stream; }; // implements the singleton factory DAGEngine* DAGEngine::Get() { diff --git a/src/narray/narray.cc b/src/narray/narray.cc index 598f30e63205..8ec3c25f8979 100644 --- a/src/narray/narray.cc +++ b/src/narray/narray.cc @@ -6,16 +6,16 @@ namespace mxnet { /*! 
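
The scheduling pattern used throughout this file, pulled out as a sketch (`ScheduleExample` is hypothetical; `Variable` is the engine's dependency token): capture operands by value so the closure keeps them alive, and declare reads and writes through the two variable lists.

```cpp
void ScheduleExample(mxnet::Context exec_ctx,
                     mxnet::DAGEngine::Variable read_var,
                     mxnet::DAGEngine::Variable write_var) {
  mxnet::DAGEngine::Get()->Push(
      [](mxnet::RunContext ctx) {
        // kernel body; runs once every producer of read_var has finished
      },
      exec_ctx,
      {read_var},    // variables this operation reads
      {write_var});  // variables this operation mutates
}
```
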
- * \brief run a binary operation, returning a new dynamically allocated NArray + * \brief run a binary operation * \param lhs left operand * \param rhs right operand * \param out the output narray * \param binary_op the real */ template -inline void BinaryEWise(const NArray &lhs, - const NArray &rhs, - NArray *out) { +inline void BinaryOp(const NArray &lhs, + const NArray &rhs, + NArray *out) { CHECK(lhs.ctx() == rhs.ctx()) << "operands context mismatch"; // if out is none, allocate space if (out->is_none()) { @@ -28,66 +28,121 @@ inline void BinaryEWise(const NArray &lhs, // important: callback must always capture by value NArray ret = *out; // redirect everything to mshadow operations - DAGEngine::Get()->Push([lhs, rhs, ret](RunContext ctx) { - ret.ptr_->CheckAndAlloc(); - switch (lhs.ctx().dev_mask) { - case cpu::kDevMask: - narray::Eval(lhs.ptr_->data, rhs.ptr_->data, ret.ptr_->data, ctx); - break; + switch (lhs.ctx().dev_mask) { + case cpu::kDevMask: + DAGEngine::Get()->Push([lhs, rhs, ret](RunContext ctx) { + ret.ptr_->CheckAndAlloc(); + narray::Eval(lhs.ptr_->data, rhs.ptr_->data, &ret.ptr_->data, ctx); + }, lhs.ctx(), {lhs.ptr_->var, rhs.ptr_->var}, {ret.ptr_->var}); + break; #if MXNET_USE_CUDA - case gpu::kDevMask: - narray::Eval(lhs.ptr_->data, rhs.ptr_->data, ret.ptr_->data, ctx); - break; + case gpu::kDevMask: + DAGEngine::Get()->Push([lhs, rhs, ret](RunContext ctx) { + ret.ptr_->CheckAndAlloc(); + narray::Eval(lhs.ptr_->data, rhs.ptr_->data, &ret.ptr_->data, ctx); + }, lhs.ctx(), {lhs.ptr_->var, rhs.ptr_->var}, {ret.ptr_->var}); + break; #endif - default: LOG(FATAL) << "GPU is not enabled"; - } - }, lhs.ctx(), {lhs.ptr_->var, rhs.ptr_->var}, {ret.ptr_->var}); + default: LOG(FATAL) << "GPU is not enabled"; + } +} + +void CopyFromTo(const NArray &from, NArray *to) { + CHECK(from.shape() == to->shape()) + << "operands shape mismatch"; + CHECK(from.shape().ndim() != 0) + << "source operands have zero dimension shape"; + // important: callback must always capture by value + NArray ret = *to; + int a = from.ctx().dev_mask; + int b = to->ctx().dev_mask; + if (a == cpu::kDevMask && b == cpu::kDevMask) { + DAGEngine::Get()->Push([from, ret](RunContext ctx) { + ret.ptr_->CheckAndAlloc(); + narray::Copy(from.ptr_->data, &ret.ptr_->data, + from.ctx(), ret.ctx(), ctx); + }, from.ctx(), {from.ptr_->var}, {ret.ptr_->var}); + } else if (a == cpu::kDevMask && b == gpu::kDevMask) { +#if MXNET_USE_CUDA + DAGEngine::Get()->Push([from, ret](RunContext ctx) { + ret.ptr_->CheckAndAlloc(); + narray::Copy(from.ptr_->data, &ret.ptr_->data, + from.ctx(), ret.ctx(), ctx); + }, ret.ctx(), {from.ptr_->var}, {ret.ptr_->var}); +#else + LOG(FATAL) << "GPU is not enabled"; +#endif + } else if (a == gpu::kDevMask && b == cpu::kDevMask) { +#if MXNET_USE_CUDA + DAGEngine::Get()->Push([from, ret](RunContext ctx) { + ret.ptr_->CheckAndAlloc(); + narray::Copy(from.ptr_->data, &ret.ptr_->data, + from.ctx(), ret.ctx(), ctx); + }, from.ctx(), {from.ptr_->var}, {ret.ptr_->var}); +#else + LOG(FATAL) << "GPU is not enabled"; +#endif + } else if (a == gpu::kDevMask && b == gpu::kDevMask) { +#if MXNET_USE_CUDA + DAGEngine::Get()->Push([from, ret](RunContext ctx) { + ret.ptr_->CheckAndAlloc(); + narray::Copy(from.ptr_->data, &ret.ptr_->data, + from.ctx(), ret.ctx(), ctx); + }, from.ctx(), {from.ptr_->var}, {ret.ptr_->var}); +#else + LOG(FATAL) << "GPU is not enabled"; +#endif + } else { + LOG(FATAL) << "unknown device mask"; + } } template -inline NArray BinaryEWiseRet(const NArray &lhs, - const NArray &rhs) { +inline 
NArray BinaryOpRet(const NArray &lhs, + const NArray &rhs) { NArray ret; - BinaryEWise(lhs, rhs, &ret); + BinaryOp(lhs, rhs, &ret); return ret; } template -inline NArray &BinaryEWiseApply(NArray *dst, - const NArray &src) { - BinaryEWise(*dst, src, dst); +inline NArray &BinaryOpApply(NArray *dst, + const NArray &src) { + BinaryOp(*dst, src, dst); return *dst; } NArray operator+(const NArray &lhs, const NArray &rhs) { - return BinaryEWiseRet(lhs, rhs); + return BinaryOpRet(lhs, rhs); } NArray operator-(const NArray &lhs, const NArray &rhs) { - return BinaryEWiseRet(lhs, rhs); + return BinaryOpRet(lhs, rhs); } NArray operator*(const NArray &lhs, const NArray &rhs) { - return BinaryEWiseRet(lhs, rhs); + return BinaryOpRet(lhs, rhs); } NArray operator/(const NArray &lhs, const NArray &rhs) { - return BinaryEWiseRet(lhs, rhs); + return BinaryOpRet(lhs, rhs); } NArray &NArray::operator+=(const NArray &src) { - return BinaryEWiseApply(this, src); + return BinaryOpApply(this, src); } NArray &NArray::operator-=(const NArray &src) { - return BinaryEWiseApply(this, src); + return BinaryOpApply(this, src); } NArray &NArray::operator*=(const NArray &src) { - return BinaryEWiseApply(this, src); + return BinaryOpApply(this, src); } NArray &NArray::operator/=(const NArray &src) { - return BinaryEWiseApply(this, src); + return BinaryOpApply(this, src); } // register API function -REGISTER_NARRAY_FUN(plus).set_function(BinaryEWise); -REGISTER_NARRAY_FUN(minus).set_function(BinaryEWise); -REGISTER_NARRAY_FUN(mul).set_function(BinaryEWise); -REGISTER_NARRAY_FUN(div).set_function(BinaryEWise); +REGISTER_NARRAY_FUN(plus).set_function(BinaryOp); +REGISTER_NARRAY_FUN(minus).set_function(BinaryOp); +REGISTER_NARRAY_FUN(mul).set_function(BinaryOp); +REGISTER_NARRAY_FUN(div).set_function(BinaryOp); + +REGISTER_NARRAY_FUN(copy).set_function(CopyFromTo); } // namespace mxnet diff --git a/src/narray/narray_op-inl.h b/src/narray/narray_op-inl.h index 9891d9a993d0..dd0660336dbe 100644 --- a/src/narray/narray_op-inl.h +++ b/src/narray/narray_op-inl.h @@ -4,8 +4,8 @@ #ifndef DECL_BINARY #define DECL_BINARY(XPU, OP, FUN) \ template<> \ - void Eval(const TBlob &lhs, const TBlob &rhs, TBlob ret, RunContext ctx) { \ - FUN(lhs, rhs, ret, ctx); \ + void Eval(const TBlob &lhs, const TBlob &rhs, TBlob *ret, RunContext ctx) { \ + FUN(lhs, rhs, ret, ctx); \ } #endif @@ -19,10 +19,11 @@ namespace mxnet { namespace narray { // true implementation template -inline void Eval_(const TBlob &lhs, const TBlob &rhs, TBlob ret, RunContext ctx) { +inline void Eval_(const TBlob &lhs, const TBlob &rhs, + TBlob *ret, RunContext ctx) { using namespace mshadow::expr; mshadow::Stream *s = static_cast*>(ctx.stream); - ret.FlatTo2D(s) + ret->FlatTo2D(s) = F(lhs.FlatTo2D(s), rhs.FlatTo2D(s)); } diff --git a/src/narray/narray_op.h b/src/narray/narray_op.h index 72abf98fa3d3..cf827268254f 100644 --- a/src/narray/narray_op.h +++ b/src/narray/narray_op.h @@ -34,7 +34,13 @@ struct Div : public BinaryBase { typedef mshadow::op::div mshadow_op; }; template -void Eval(const TBlob &lhs, const TBlob &rhs, TBlob ret, RunContext ctx); +void Eval(const TBlob &lhs, const TBlob &rhs, TBlob *ret, RunContext ctx); + +// copy function when only cpu is involved +template +void Copy(const TBlob &from, TBlob *to, + Context from_ctx, Context to_ctx, + RunContext ctx); } // namespace narray } // namespace mxnet diff --git a/src/narray/narray_op_cpu.cc b/src/narray/narray_op_cpu.cc index 9e59be609688..b6c7014964ad 100644 --- a/src/narray/narray_op_cpu.cc +++ 
b/src/narray/narray_op_cpu.cc @@ -1,3 +1,15 @@ // this will be invoked by gcc and compile CPU version #include "./narray_op.h" #include "./narray_op-inl.h" + +namespace mxnet { +namespace narray { +template<> +void Copy(const TBlob &from, TBlob *to, + Context from_ctx, Context to_ctx, + RunContext ctx) { + mshadow::Copy(to->FlatTo2D(), + from.FlatTo2D()); +} +} // namespace narray +} // namespace mxnet diff --git a/src/narray/narray_op_gpu.cu b/src/narray/narray_op_gpu.cu index 335be54c27ca..571757e41ee8 100644 --- a/src/narray/narray_op_gpu.cu +++ b/src/narray/narray_op_gpu.cu @@ -1,3 +1,48 @@ // this will be invoked by nvcc and compile GPU version +#include #include "./narray_op.h" #include "./narray_op-inl.h" + +namespace mxnet { +namespace narray { +template<> +void Copy(const TBlob &from, TBlob *to, + Context from_ctx, Context to_ctx, + RunContext ctx) { + mshadow::Copy(to->FlatTo2D(), + from.FlatTo2D(), + static_cast*>(ctx.stream)); +} + +template<> +void Copy(const TBlob &from, TBlob *to, + Context from_ctx, Context to_ctx, + RunContext ctx) { + mshadow::Copy(to->FlatTo2D(), + from.FlatTo2D(), + static_cast*>(ctx.stream)); +} + +template<> +void Copy(const TBlob &from, TBlob *to, + Context from_ctx, Context to_ctx, + RunContext ctx) { + if (from_ctx.dev_id == to_ctx.dev_id) { + mshadow::Copy(to->FlatTo2D(), + from.FlatTo2D(), + static_cast*>(ctx.stream)); + } else { + CHECK(from.CheckContiguous() && to->CheckContiguous()) + << "copy across only support continugous memory"; + mshadow::Stream *s = static_cast*>(ctx.stream); + CHECK(s != NULL) << "need stream in GPU context"; + cudaMemcpyPeerAsync(to->dptr_, + to_ctx.dev_id, + from.dptr_, + from_ctx.dev_id, + from.shape_.Size() * sizeof(real_t), + s->stream_); + } +} +} // namespace narray +} // namespace mxnet diff --git a/src/storage/storage.cc b/src/storage/storage.cc index 342c898801b7..ce24a94e8ade 100644 --- a/src/storage/storage.cc +++ b/src/storage/storage.cc @@ -1,3 +1,4 @@ +#include #include namespace mxnet { class NaiveStorageManager : public StorageManager { @@ -9,14 +10,25 @@ class NaiveStorageManager : public StorageManager { StorageManager::Handle NaiveStorageManager::Alloc(size_t size, Context ctx) { Handle hd; - hd.dptr = new char[size]; hd.ctx = ctx; - hd.handle_ = NULL; + hd.handle_ = NULL; + if (ctx.dev_mask == cpu::kDevMask) { + cudaMallocHost(&hd.dptr, size); + } else { +#if MXNET_USE_CUDA + cudaMalloc(&hd.dptr, size); +#endif + } return hd; } void NaiveStorageManager::Free(StorageManager::Handle handle) { - char *dptr = static_cast(handle.dptr); - delete [] dptr; + if (handle.ctx.dev_mask == cpu::kDevMask) { + cudaFreeHost(handle.dptr); + } else { +#if MXNET_USE_CUDA + cudaFree(handle.dptr); +#endif + } } StorageManager *StorageManager::Get() { static NaiveStorageManager inst; From 438ab2b0297f2c36ab3e11c9d5bde7da555cc0cb Mon Sep 17 00:00:00 2001 From: Minjie Wang Date: Sun, 28 Jun 2015 21:11:06 -0400 Subject: [PATCH 09/11] threaded engine draft --- .gitignore | 5 + Makefile | 3 +- src/common/concurrent_blocking_queue.h | 79 +++++++++++ src/common/spin_lock.h | 45 +++++++ src/dag_engine/threaded_engine.cc | 179 +++++++++++++++++++++++++ test/test_threaded_engine.cc | 9 ++ 6 files changed, 319 insertions(+), 1 deletion(-) create mode 100644 src/common/concurrent_blocking_queue.h create mode 100644 src/common/spin_lock.h create mode 100644 src/dag_engine/threaded_engine.cc create mode 100644 test/test_threaded_engine.cc diff --git a/.gitignore b/.gitignore index a63de96ac6d6..3292dfc4e309 100644 --- a/.gitignore 
+++ b/.gitignore
@@ -30,3 +30,8 @@
 dmlc-core
 mshadow
 config.mk
+
+# vim
+*.swp
+*.swo
+*.swn
diff --git a/Makefile b/Makefile
index b159e0bc9429..49dde6b7b687 100644
--- a/Makefile
+++ b/Makefile
@@ -48,7 +48,7 @@ ifneq ($(ADD_LDFLAGS), NONE)
 endif
 OBJ = storage.o narray_op_cpu.o operator.o operator_cpu.o
-OBJCXX11 = engine.o narray.o
+OBJCXX11 = engine.o narray.o threaded_engine.o
 CUOBJ = narray_op_gpu.o operator_gpu.o
 LIB_DEP = $(DMLC_CORE)/libdmlc.a
@@ -62,6 +62,7 @@ $(DMLC_CORE)/libdmlc.a:
 storage.o: src/storage/storage.cc
 engine.o: src/dag_engine/simple_engine.cc
+threaded_engine.o: src/dag_engine/threaded_engine.cc src/common/concurrent_blocking_queue.h src/common/spin_lock.h
 narray.o: src/narray/narray.cc
 narray_op_cpu.o: src/narray/narray_op_cpu.cc src/narray/narray_op-inl.h
 narray_op_gpu.o: src/narray/narray_op_gpu.cu src/narray/narray_op-inl.h
diff --git a/src/common/concurrent_blocking_queue.h b/src/common/concurrent_blocking_queue.h
new file mode 100644
index 000000000000..aab39895b119
--- /dev/null
+++ b/src/common/concurrent_blocking_queue.h
@@ -0,0 +1,79 @@
+#pragma once
+#include <atomic>
+#include <cassert>
+#include <condition_variable>
+#include <list>
+#include <mutex>
+
+template<typename T> class ConcurrentBlockingQueue {
+  const static int BUSY_LOOP = 1000;
+ public:
+  ConcurrentBlockingQueue() : has_elmt_(false), exit_now_(false) {
+  }
+  void Push(const T& e) {
+    std::lock_guard<std::mutex> lock(mutex_);
+    has_elmt_ = true;
+    queue_.push_back(e);
+    if (queue_.size() == 1) {
+      cv_.notify_all();
+    }
+  }
+  bool Pop(T& rv) {
+    // busy-poll first to avoid the lock on a hot queue
+    for (int i = 0; i < BUSY_LOOP; i++) {
+      if (has_elmt_) {
+        std::lock_guard<std::mutex> lock(mutex_);
+        if (!has_elmt_) {
+          assert(queue_.empty());
+          continue;
+        }
+        rv = queue_.front();
+        queue_.pop_front();
+        if (queue_.empty())
+          has_elmt_ = false;
+        return false;
+      }
+    }
+    {
+      std::unique_lock<std::mutex> lock(mutex_);
+      while (queue_.empty() && !exit_now_) {
+        cv_.wait(lock);
+      }
+      if (!exit_now_) {
+        rv = queue_.front();
+        queue_.pop_front();
+        if (queue_.empty())
+          has_elmt_ = false;
+        return false;
+      } else {
+        return true;
+      }
+    }
+  }
+  std::list<T> PopAll() {
+    std::lock_guard<std::mutex> lock(mutex_);
+    std::list<T> rv;
+    rv.swap(queue_);
+    return rv;
+  }
+  // Call `SignalForKill` before destruction
+  void SignalForKill() {
+    std::unique_lock<std::mutex> lock(mutex_);
+    exit_now_ = true;
+    cv_.notify_all();
+  }
+  size_t QueueSize() {
+    std::unique_lock<std::mutex> lock(mutex_);
+    return queue_.size();
+  }
+
+ private:
+  std::atomic<bool> has_elmt_;
+  std::list<T> queue_;
+  std::mutex mutex_;
+  std::condition_variable cv_;
+  std::atomic<bool> exit_now_;
+
+  ConcurrentBlockingQueue(const ConcurrentBlockingQueue&) = delete;
+  ConcurrentBlockingQueue& operator=(const ConcurrentBlockingQueue&) = delete;
+};
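The Pop contract above is easy to misread: it returns false when it has delivered an element and true only after SignalForKill, so a worker loops until Pop returns true. A minimal usage sketch (hypothetical main, not part of this patch; assumes the header's relative path):

    #include <thread>
    #include "src/common/concurrent_blocking_queue.h"

    int main() {
      ConcurrentBlockingQueue<int> q;
      std::thread consumer([&q]() {
        int v;
        while (!q.Pop(v)) {  // false == got an element, true == killed
          // process v here
        }
      });
      for (int i = 0; i < 10; ++i) q.Push(i);
      q.SignalForKill();     // must be called before destruction
      consumer.join();
      return 0;
    }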
diff --git a/src/common/spin_lock.h b/src/common/spin_lock.h
new file mode 100644
index 000000000000..5a0cc3f786e6
--- /dev/null
+++ b/src/common/spin_lock.h
@@ -0,0 +1,45 @@
+#ifndef _SPINLOCK_XCHG_H
+#define _SPINLOCK_XCHG_H
+
+/* Spin lock using xchg.
+ * Copied from http://locklessinc.com/articles/locks/
+ */
+
+/* Compile read-write barrier */
+#define barrier() asm volatile("": : :"memory")
+
+/* Pause instruction to prevent excess processor bus usage */
+#define cpu_relax() asm volatile("pause\n": : :"memory")
+
+static inline unsigned short xchg_8(void *ptr, unsigned char x) {
+  __asm__ __volatile__("xchgb %0,%1"
+                       :"=r" (x)
+                       :"m" (*(volatile unsigned char *)ptr), "0" (x)
+                       :"memory");
+
+  return x;
+}
+
+#define BUSY 1
+typedef unsigned char spinlock;
+
+#define SPINLOCK_INITIALIZER 0
+
+static inline void spin_lock(spinlock *lock) {
+  while (1) {
+    if (!xchg_8(lock, BUSY)) return;
+
+    while (*lock) cpu_relax();
+  }
+}
+
+static inline void spin_unlock(spinlock *lock) {
+  barrier();
+  *lock = 0;
+}
+
+static inline int spin_trylock(spinlock *lock) {
+  return xchg_8(lock, BUSY);
+}
+
+#endif /* _SPINLOCK_XCHG_H */
diff --git a/src/dag_engine/threaded_engine.cc b/src/dag_engine/threaded_engine.cc
new file mode 100644
index 000000000000..143b5e72f413
--- /dev/null
+++ b/src/dag_engine/threaded_engine.cc
@@ -0,0 +1,179 @@
+#include <queue>
+#include <memory>
+#include <random>
+#include <thread>
+#include <vector>
+#include <utility>
+
+#include <dmlc/logging.h>
+#include <mxnet/dag_engine.h>
+#include "../common/spin_lock.h"
+#include "../common/concurrent_blocking_queue.h"
+
+using namespace std;
+
+namespace mxnet {
+
+#define DEFAULT_NUM_WORKER_THREADS 4
+
+class ThreadedEngine : public DAGEngine {
+ public:
+  ThreadedEngine(int numthreads = DEFAULT_NUM_WORKER_THREADS): numthreads_(numthreads) {
+    for(int i = 0; i < numthreads; ++i) {
+      worker_queues_.push_back(new ConcurrentBlockingQueue<OpDescr*>());
+      workers_.emplace_back(&ThreadedEngine::WorkerRoutine, this, i);
+    }
+  }
+  ~ThreadedEngine() {
+    for(int i = 0; i < numthreads_; ++i) {
+      worker_queues_[i]->SignalForKill();
+      delete worker_queues_[i];
+      workers_[i].join();
+    }
+  }
+  void Push(AsyncOp exec_fun,
+            Context exec_ctx,
+            const vector<Variable> &use_vars,
+            const vector<Variable> &mutate_vars) override {
+    shared_ptr<OpDescr> opd( new OpDescr{exec_fun, exec_ctx, use_vars, mutate_vars},
+                             [this] (OpDescr* o) { this->OnDepsResolved(o); } );
+    for( Variable v : use_vars ) { // read
+      VarDescr* vard = static_cast<VarDescr*>(v); // safe to cast here
+      spin_lock(&vard->lock);
+      if (vard->rw < 0) {
+        vard->waitings.push(make_pair(opd, DepType::kRead));
+      } else {
+        ++vard->rw;
+      }
+      spin_unlock(&vard->lock);
+    }
+    for( Variable v : mutate_vars ) { // write
+      VarDescr* vard = static_cast<VarDescr*>(v); // safe to cast here
+      spin_lock(&vard->lock);
+      if (vard->rw != 0) {
+        vard->waitings.push(make_pair(opd, DepType::kWrite));
+      } else {
+        vard->rw = -1;
+      }
+      spin_unlock(&vard->lock);
+    }
+  }
+  void Push(Op exec_fun,
+            Context exec_ctx,
+            const vector<Variable> &use_vars,
+            const vector<Variable> &mutate_vars) override {
+    this->Push([exec_fun](RunContext ctx, Callback on_complete) {
+        exec_fun(ctx); on_complete();
+      }, exec_ctx, use_vars, mutate_vars);
+  }
+  void PushDelete(Op delete_fun, Variable var) override {
+    // TODO
+    this->Push([delete_fun, var] (RunContext ctx) {
+        delete_fun(ctx);
+        delete static_cast<VarDescr*>(var);
+      }, Context()/* TODO exec_ctx is missing?*/, {}, {var});
+  }
+  Variable NewVar() override {
+    // in practice return a ptr to a cell
+    // that has the info about the variable
+    // use ptr directly instead of ID because this avoids an indirect mapping
+    VarDescr* vd = new VarDescr;
+    vd->lock = SPINLOCK_INITIALIZER;
+    vd->rw = 0;
+    return vd;
+  }
+  void WaitForVar(Variable var) override {
+    // TODO
+  }
+  void WaitForAll() override {
+    // TODO
+  }
+ private:
+  enum class DepType {
+    kRead = 0,
+    kWrite,
+    kDelete,
+  };
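+  // Note on the scheduling protocol below: Push wraps every operation in a
+  // shared_ptr whose custom deleter is OnDepsResolved. Each variable that
+  // cannot be acquired right away keeps one reference in its waiting queue,
+  // so the operation is dispatched to a worker exactly when the last
+  // blocking variable releases its reference.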
+  struct OpDescr {
+    AsyncOp op;
+    Context exec_ctx;
+    vector<Variable> read_vars;
+    vector<Variable> write_vars;
+  };
+  struct VarDescr {
+    spinlock lock;
+    int rw; // a semaphore-like count
+            // if rw > 0, the variable has several readers and the number
+            // means how many operators are currently reading it;
+            // if rw < 0, the variable has one writer (should be -1)
+    queue<pair<shared_ptr<OpDescr>, DepType>> waitings;
+  };
+  void TriggerWaiting(VarDescr* vard) {
+    // ATTENTION: this function should be called with vard->lock held.
+    CHECK(vard->rw == 0) << "the variable should be free during triggering";
+    if(!vard->waitings.empty()) {
+      // pop all reads first
+      while(!vard->waitings.empty()
+            && vard->waitings.front().second == DepType::kRead) {
+        vard->waitings.pop();
+        ++vard->rw;
+      }
+      if (vard->rw == 0) {
+        // no reads were pending, so the head is a write (or delete);
+        // pop it and mark the variable as being written
+        vard->waitings.pop();
+        vard->rw = -1;
+      }
+    }
+  }
+  void OnOpFinished(OpDescr* opd) {
+    CHECK(opd) << "completing a nullptr op!";
+    for(Variable v : opd->read_vars) {
+      VarDescr* vard = static_cast<VarDescr*>(v); // safe to cast here
+      spin_lock(&vard->lock);
+      CHECK(vard->rw > 0) << "incorrect rw count (reader):" << vard->rw;
+      if(--vard->rw == 0) {
+        TriggerWaiting(vard);
+      }
+      spin_unlock(&vard->lock);
+    }
+    for(Variable v : opd->write_vars) {
+      VarDescr* vard = static_cast<VarDescr*>(v); // safe to cast here
+      spin_lock(&vard->lock);
+      CHECK(vard->rw == -1) << "incorrect rw count (writer):" << vard->rw;
+      vard->rw = 0;
+      TriggerWaiting(vard);
+      spin_unlock(&vard->lock);
+    }
+    delete opd; // delete the operator
+  }
+  RunContext GetRunContext(const Context& ctx) {
+    // TODO
+    return RunContext();
+  }
+  void OnDepsResolved(OpDescr* opd) {
+    static default_random_engine generator;
+    // pick a worker uniformly at random; the upper bound of
+    // uniform_int_distribution is inclusive, so it must be numthreads_ - 1
+    static uniform_int_distribution<int> distribution(0, numthreads_ - 1);
+    int thrid = distribution(generator);
+    worker_queues_[thrid]->Push(opd);
+  }
+  void WorkerRoutine(int thrid) {
+    OpDescr* opd = nullptr;
+    while(!
worker_queues_[thrid]->Pop(opd)) { + LOG(INFO) << "worker thread #" << thrid << " got operator " << opd; + opd->op(GetRunContext(opd->exec_ctx), [this, opd] () { this->OnOpFinished(opd); }); + opd = nullptr; + } + } + private: + const int numthreads_; + vector*> worker_queues_; + vector workers_; +}; + +// implements the singleton factory +DAGEngine* DAGEngine::Get() { + static ThreadedEngine engine; + return &engine; +} +} // namespace mxnet diff --git a/test/test_threaded_engine.cc b/test/test_threaded_engine.cc new file mode 100644 index 000000000000..40dea029cf6e --- /dev/null +++ b/test/test_threaded_engine.cc @@ -0,0 +1,9 @@ +#include + +using namespace std; +using namespace mxnet; + +int main() { + DAGEngine* engine = DAGEngine::Get(); + return 0; +} From 59a7ceec0b251ea76a580b061b50122590c0fc8c Mon Sep 17 00:00:00 2001 From: Bing Xu Date: Tue, 30 Jun 2015 00:45:58 -0600 Subject: [PATCH 10/11] refactor api --- Makefile | 3 +- api/mxnet_api.cc | 10 +-- api/mxnet_api.h | 10 +-- api/python/mxnet/__init__.py | 7 +- api/python/mxnet/base.py | 43 ------------ api/python/mxnet/function.py | 128 +++++++++++++++++++++++++++++++++++ api/python/mxnet/narray.py | 29 +++++--- api/python/test_python.py | 3 + include/mxnet/api_registry.h | 33 ++++++++- src/narray/narray.cc | 7 +- 10 files changed, 210 insertions(+), 63 deletions(-) create mode 100644 api/python/mxnet/function.py diff --git a/Makefile b/Makefile index e037062bd9fa..8892255f4386 100644 --- a/Makefile +++ b/Makefile @@ -48,7 +48,8 @@ endif BIN = test/api_registry_test OBJ = storage.o narray_op_cpu.o operator.o operator_cpu.o -OBJCXX11 = engine.o narray.o mxnet_api.o api_registry.o threaded_engine.o +# add threaded engine after it is done +OBJCXX11 = engine.o narray.o mxnet_api.o api_registry.o engine.o CUOBJ = narray_op_gpu.o operator_gpu.o SLIB = api/libmxnet.so ALIB = api/libmxnet.a diff --git a/api/mxnet_api.cc b/api/mxnet_api.cc index 3fefa4473570..5ca45257f786 100644 --- a/api/mxnet_api.cc +++ b/api/mxnet_api.cc @@ -164,15 +164,17 @@ int MXFuncGetName(FunctionHandle fun, API_END(); } -int MXFuncDescribeArgs(FunctionHandle fun, - mx_uint *num_use_vars, - mx_uint *num_scalars, - mx_uint *num_mutate_vars) { +int MXFuncDescribe(FunctionHandle fun, + mx_uint *num_use_vars, + mx_uint *num_scalars, + mx_uint *num_mutate_vars, + int *type_mask) { API_BEGIN(); auto *f = static_cast(fun); *num_use_vars = f->num_use_vars; *num_scalars = f->num_scalars; *num_mutate_vars = f->num_mutate_vars; + *type_mask = f->type_mask; API_END(); } diff --git a/api/mxnet_api.h b/api/mxnet_api.h index 517de39f180a..0710f9cee37a 100644 --- a/api/mxnet_api.h +++ b/api/mxnet_api.h @@ -173,13 +173,15 @@ MXNET_DLL int MXFuncGetName(FunctionHandle fun, * \param num_use_vars how many NArrays to be passed in as used_vars * \param num_scalars scalar variable is needed * \param num_mutate_vars how many NArrays to be passed in as mutate_vars + * \param type_mask the type mask of this function * \return 0 when success, -1 when failure happens * \sa MXFuncInvoke */ -MXNET_DLL int MXFuncDescribeArgs(FunctionHandle fun, - mx_uint *num_use_vars, - mx_uint *num_scalars, - mx_uint *num_mutate_vars); +MXNET_DLL int MXFuncDescribe(FunctionHandle fun, + mx_uint *num_use_vars, + mx_uint *num_scalars, + mx_uint *num_mutate_vars, + int *type_mask); /*! 
* \brief invoke a function, the array size of passed in arguments diff --git a/api/python/mxnet/__init__.py b/api/python/mxnet/__init__.py index 36f7732167a4..c78c0d485159 100644 --- a/api/python/mxnet/__init__.py +++ b/api/python/mxnet/__init__.py @@ -8,5 +8,10 @@ Version : 0.10 """ from __future__ import absolute_import -from .narray import NArray + from .context import Context, current_context +from .narray import NArray, _init_function_registry +from .function import _FunctionRegistry + +# this is a global function registry that can be used to invoke functions +op = _init_function_registry(_FunctionRegistry()) diff --git a/api/python/mxnet/base.py b/api/python/mxnet/base.py index 581778ce6fb8..441c67fd092d 100644 --- a/api/python/mxnet/base.py +++ b/api/python/mxnet/base.py @@ -133,46 +133,3 @@ def ctypes2numpy_shared(cptr, shape): dbuffer = (mx_float * size).from_address(ctypes.addressof(cptr.contents)) return np.frombuffer(dbuffer, dtype = np.float32).reshape(shape) -#------------------------------ -# get list of functon pointers -#------------------------------ -class _FunctionRegistry: - def __init__(self): - plist = ctypes.POINTER(ctypes.c_void_p)() - size = ctypes.c_uint() - check_call(lib.MXListFunctions(ctypes.byref(size), - ctypes.byref(plist))) - hmap = {} - for i in range(size.value): - h = plist[i] - name = ctypes.c_char_p() - check_call(lib.MXFuncGetName(h, ctypes.byref(name))) - hmap[name.value] = h - self.__dict__.update(hmap) - -# handle to function registry -op = _FunctionRegistry() - - -def invoke(fhandle, used_vars, scalars, mutate_vars): - """Invoke a function handle by passing in arguments as tuples - - Parameters - ---------- - fhandle : FunctionHandle - function handle of C API - - used_vars : tuple - tuple of NArray handles - - scalars : tuple - tuple of real number arguments - - mutate_vars : tuple - tuple of NArray handles to mutate - """ - check_call(lib.MXFuncInvoke( - fhandle, - c_array(NArrayHandle, used_vars), - c_array(mx_float, scalars), - c_array(NArrayHandle, mutate_vars))) diff --git a/api/python/mxnet/function.py b/api/python/mxnet/function.py new file mode 100644 index 000000000000..f6c05a25c7d2 --- /dev/null +++ b/api/python/mxnet/function.py @@ -0,0 +1,128 @@ +# coding: utf-8 +"""NArray functions support of mxnet""" +from __future__ import absolute_import + +import ctypes +from .base import lib +from .base import c_array +from .base import mx_uint, mx_float, NArrayHandle, FunctionHandle +from .base import check_call +from .narray import NArray, _new_empty_handle + +class _Function: + # constants for type masks + NARRAY_ARG_BEFORE_SCALAR = 1 + SCALAR_ARG_BEFORE_NARRAY = 2 + ACCEPT_EMPTY_MUTATE_TARGET = 3 + + def __init__(self, handle, name): + """Initialize the function with handle + + Parameters + ---------- + handle : FunctionHandle + the function handle of the function + + name : string + the name of the function + """ + self.handle = handle + n_used_vars = mx_uint() + n_scalars = mx_uint() + n_mutate_vars = mx_uint() + type_mask = ctypes.c_int() + check_call(lib.MXFuncDescribe( + self.handle, + ctypes.byref(n_used_vars), + ctypes.byref(n_scalars), + ctypes.byref(n_mutate_vars), + ctypes.byref(type_mask))) + self.n_used_vars = n_used_vars.value + self.n_scalars = n_scalars.value + self.n_mutate_vars = n_mutate_vars.value + self.type_mask = type_mask.value + # infer type of the function + if (self.type_mask & _Function.NARRAY_ARG_BEFORE_SCALAR) != 0: + self.use_vars_range = range(0, self.n_used_vars) + self.scalar_range = range(self.n_used_vars, 
+ self.n_used_vars + self.n_scalars) + else: + self.scalar_range = range(0, self.n_scalars) + self.use_vars_range = range(self.n_scalars, + self.n_scalars + self.n_used_vars) + self.accept_empty_mutate = (self.type_mask & + _Function.ACCEPT_EMPTY_MUTATE_TARGET) != 0 + + def __call__(self, *args, **kwargs): + """Invoke this function by passing in parameters + + Parameters + ---------- + *args: positional arguments + positional arguments of input scalars and NArray + + mutate_vars: kwarg(optional) + provide the NArray to store the result of the operation + + Returns + ------- + the result NArrays of mutated result + """ + if 'mutate_vars' in kwargs: + mutate_vars = kwargs['mutate_vars'] + if len(mutate_vars) != self.n_mutate_vars: + raise MXNetError('expect %d mutate_vars in function %s', self.n_mutate_vars, self.name) + else: + if self.accept_empty_mutate: + mutate_vars = tuple( + NArray(_new_empty_handle()) for i in range(self.n_mutate_vars)) + else: + raise MXNetError('mutate_vars argument is required to call this function') + + self.invoke_with_handle_([args[i].handle for i in self.use_vars_range], + [args[i] for i in self.scalar_range], + [v.handle for v in mutate_vars]) + if self.n_mutate_vars == 1: + return mutate_vars[0] + else: + return mutate_vars + + def invoke_with_handle_(self, use_vars, scalars, mutate_vars): + """Invoke this function by passing in arguments as tuples + + This is a very primitive call to the function handle that + involves passing in a C handle + + Parameters + ---------- + fhandle : FunctionHandle + function handle of C API + + use_vars : tuple + tuple of NArray handles + + scalars : tuple + tuple of real number arguments + + mutate_vars : tuple + tuple of NArray handles to mutate + """ + check_call(lib.MXFuncInvoke( + self.handle, + c_array(NArrayHandle, use_vars), + c_array(mx_float, scalars), + c_array(NArrayHandle, mutate_vars))) + +class _FunctionRegistry: + def __init__(self): + plist = ctypes.POINTER(ctypes.c_void_p)() + size = ctypes.c_uint() + check_call(lib.MXListFunctions(ctypes.byref(size), + ctypes.byref(plist))) + hmap = {} + for i in range(size.value): + h = plist[i] + name = ctypes.c_char_p() + check_call(lib.MXFuncGetName(h, ctypes.byref(name))) + hmap[name.value] = _Function(h, name.value) + self.__dict__.update(hmap) diff --git a/api/python/mxnet/narray.py b/api/python/mxnet/narray.py index eca23f2ad3c1..9149921266e9 100644 --- a/api/python/mxnet/narray.py +++ b/api/python/mxnet/narray.py @@ -5,15 +5,15 @@ import ctypes import numpy as np from .base import lib -from .base import op from .base import c_array from .base import mx_uint, mx_float, NArrayHandle from .base import ctypes2numpy_shared -from .base import invoke from .base import check_call from .base import MXNetError from .context import Context +global op + def _new_empty_handle(): """Return a new empty handle @@ -67,7 +67,7 @@ def __del__(self): def __add__(self, other): hret = _new_empty_handle() if isinstance(other, NArray): - invoke(op.plus, (other.handle, self.handle), (), (hret,)) + op.plus.invoke_with_handle_((other.handle, self.handle), (), (hret,)) else: raise MXNetError('type %s not supported' % str(type(other))) return NArray(handle = hret) @@ -78,7 +78,7 @@ def __radd__(self, other): def __sub__(self, other): hret = _new_empty_handle() if isinstance(other, NArray): - invoke(op.minus, (other.handle, self.handle), (), (hret,)) + op.minus.invoke_with_handle_((other.handle, self.handle), (), (hret,)) else: raise MXNetError('type %s not supported' % str(type(other))) return 
NArray(handle = hret) @@ -86,7 +86,7 @@ def __sub__(self, other): def __mul__(self, other): hret = _new_empty_handle() if isinstance(other, NArray): - invoke(op.mul, (other.handle, self.handle), (), (hret,)) + op.mul.invoke_with_handle_((other.handle, self.handle), (), (hret,)) else: raise MXNetError('type %s not supported' % str(type(other))) return NArray(handle = hret) @@ -97,7 +97,7 @@ def __rmul__(self, other): def __div__(self, other): hret = _new_empty_handle() if isinstance(other, NArray): - invoke(op.div, (other.handle, self.handle), (), (hret,)) + op.div.invoke_with_handle_((other.handle, self.handle), (), (hret,)) else: raise MXNetError('type %s not supported' % str(type(other))) return NArray(handle = hret) @@ -167,11 +167,11 @@ def copyto(self, other): the copy target NArray """ if isinstance(other, NArray): - invoke(op.copy, (self.handle,), (), (other.handle,)) + op.copy.invoke_with_handle_((self.handle,), (), (other.handle,)) return other elif isinstance(other, Context): hret = _new_alloc_handle(self.shape, other, True) - invoke(op.copy, (self.handle,), (), (hret,)) + op.copy.invoke_with_handle_((self.handle,), (), (hret,)) return NArray(handle = hret) else: raise MXNetError('copyto do not support type ' + type(other)) @@ -190,3 +190,16 @@ def create(shape, ctx = Context.default_ctx): """ return NArray(handle = _new_alloc_handle(shape, ctx, False)) +def _init_function_registry(new_op): + """Initialize the global variable op with new_op + + This function is used to resolve cyclic dependency of .narray on function + + Parameters + ---------- + new_op : function._FunctionRegistry + a FunctionRegistry to pass in in startup + """ + global op + op = new_op + return op diff --git a/api/python/test_python.py b/api/python/test_python.py index f9e8003c6e06..90154f020945 100644 --- a/api/python/test_python.py +++ b/api/python/test_python.py @@ -8,7 +8,10 @@ c = b * a +cc = mx.op.mul(b, a) + print(c.context) +print(cc.numpy) d = c.copyto(mx.Context('cpu', 0)) diff --git a/include/mxnet/api_registry.h b/include/mxnet/api_registry.h index 253eef1574a1..e08ab03547fb 100644 --- a/include/mxnet/api_registry.h +++ b/include/mxnet/api_registry.h @@ -17,13 +17,31 @@ #include "./narray.h" namespace mxnet { + +/*! \brief mask information on how functions can be exposed */ +enum FunctionTypeMask { + /*! \brief all the use_vars should go before scalar */ + kNArrayArgBeforeScalar = 1, + /*! \brief all the scalar should go before use_vars */ + kScalarArgBeforeNArray = 1 << 1, + /*! + * \brief whether this function allows the handles in the target to + * be empty NArray that are not yet initialized, and will initialize + * them when the function is invoked. + * + * most function should support this, except copy between different + * devices, which requires the NArray to be pre-initialized with context + */ + kAcceptEmptyMutateTarget = 1 << 2 +}; + /*! \brief registry of NArray functions */ class FunctionRegistry { public: /*! \brief definition of NArray function */ typedef std::function Function; + NArray **mutate_vars)> Function; /*! \brief registry entry */ struct Entry { /*! \brief function name */ @@ -34,6 +52,8 @@ class FunctionRegistry { unsigned num_mutate_vars; /*! \brief number of scalars used by this function */ unsigned num_scalars; + /*! \brief information on how function should be called from API */ + int type_mask; /*! \brief the real function */ Function body; /*! 
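Since the masks are independent bits they can be OR-ed together; a small sketch of how an entry's mask is composed and later tested, using the values from the FunctionTypeMask enum above:

    // composing: what the binary set_function helper records for its entry
    int type_mask = kNArrayArgBeforeScalar | kAcceptEmptyMutateTarget;  // 1 | 4 == 5
    // testing: what a caller (such as the Python wrapper) does bit by bit
    bool accept_empty = (type_mask & kAcceptEmptyMutateTarget) != 0;    // true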
@@ -45,6 +65,7 @@ class FunctionRegistry { num_use_vars(0), num_mutate_vars(0), num_scalars(0), + type_mask(0), body(nullptr) {} /*! * \brief set the number of mutate variables @@ -78,6 +99,14 @@ class FunctionRegistry { inline Entry &set_body(Function f) { body = f; return *this; } + /*! + * \brief set the function body + * \param f function body to set + * \return ref to the registered entry, used to set properties + */ + inline Entry &set_type_mask(int tmask) { + type_mask = tmask; return *this; + } /*! * \brief set the function body to a binary NArray function * this will also auto set the parameters correctly @@ -92,6 +121,7 @@ class FunctionRegistry { fbinary(*used_vars[0], *used_vars[1], mutate_vars[0]); }; num_use_vars = 2; num_mutate_vars = 1; + type_mask = kNArrayArgBeforeScalar | kAcceptEmptyMutateTarget; return *this; } /*! @@ -107,6 +137,7 @@ class FunctionRegistry { funary(*used_vars[0], mutate_vars[0]); }; num_use_vars = 1; num_mutate_vars = 1; + type_mask = kNArrayArgBeforeScalar | kAcceptEmptyMutateTarget; return *this; } /*! diff --git a/src/narray/narray.cc b/src/narray/narray.cc index 8ec3c25f8979..e03a2c374190 100644 --- a/src/narray/narray.cc +++ b/src/narray/narray.cc @@ -144,5 +144,10 @@ REGISTER_NARRAY_FUN(minus).set_function(BinaryOp); REGISTER_NARRAY_FUN(mul).set_function(BinaryOp); REGISTER_NARRAY_FUN(div).set_function(BinaryOp); -REGISTER_NARRAY_FUN(copy).set_function(CopyFromTo); +// copy function is special +//that we need to remove kAcceptEmptyMutateTarget from it +REGISTER_NARRAY_FUN(copy) +.set_function(CopyFromTo) +.set_type_mask(kNArrayArgBeforeScalar); + } // namespace mxnet From 480604fde741cc8d17fc5377d28069661cea875c Mon Sep 17 00:00:00 2001 From: Bing Xu Date: Tue, 30 Jun 2015 00:58:19 -0600 Subject: [PATCH 11/11] minor fix --- api/python/mxnet/function.py | 15 +++++++++------ api/python/mxnet/narray.py | 2 ++ api/python/test_python.py | 1 - 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/api/python/mxnet/function.py b/api/python/mxnet/function.py index f6c05a25c7d2..149d88a0f450 100644 --- a/api/python/mxnet/function.py +++ b/api/python/mxnet/function.py @@ -6,14 +6,14 @@ from .base import lib from .base import c_array from .base import mx_uint, mx_float, NArrayHandle, FunctionHandle -from .base import check_call +from .base import check_call, MXNetError from .narray import NArray, _new_empty_handle class _Function: # constants for type masks NARRAY_ARG_BEFORE_SCALAR = 1 - SCALAR_ARG_BEFORE_NARRAY = 2 - ACCEPT_EMPTY_MUTATE_TARGET = 3 + SCALAR_ARG_BEFORE_NARRAY = 1 << 1 + ACCEPT_EMPTY_MUTATE_TARGET = 1 << 2 def __init__(self, handle, name): """Initialize the function with handle @@ -27,6 +27,7 @@ def __init__(self, handle, name): the name of the function """ self.handle = handle + self.name = name n_used_vars = mx_uint() n_scalars = mx_uint() n_mutate_vars = mx_uint() @@ -70,14 +71,16 @@ def __call__(self, *args, **kwargs): """ if 'mutate_vars' in kwargs: mutate_vars = kwargs['mutate_vars'] + if isinstance(mutate_vars, NArray): + mutate_vars = (mutate_vars,) if len(mutate_vars) != self.n_mutate_vars: - raise MXNetError('expect %d mutate_vars in function %s', self.n_mutate_vars, self.name) + raise MXNetError('expect %d mutate_vars in op.%s', self.n_mutate_vars, self.name) else: if self.accept_empty_mutate: mutate_vars = tuple( NArray(_new_empty_handle()) for i in range(self.n_mutate_vars)) - else: - raise MXNetError('mutate_vars argument is required to call this function') + else: + raise MXNetError('mutate_vars argument is required to 
call op.%s' % self.name) self.invoke_with_handle_([args[i].handle for i in self.use_vars_range], [args[i] for i in self.scalar_range], diff --git a/api/python/mxnet/narray.py b/api/python/mxnet/narray.py index 9149921266e9..b16270cdba61 100644 --- a/api/python/mxnet/narray.py +++ b/api/python/mxnet/narray.py @@ -12,6 +12,8 @@ from .base import MXNetError from .context import Context +# op is implicitly imported from .function +# as a singleton of _FunctionRegistry global op def _new_empty_handle(): diff --git a/api/python/test_python.py b/api/python/test_python.py index 90154f020945..4edab1247e1a 100644 --- a/api/python/test_python.py +++ b/api/python/test_python.py @@ -12,7 +12,6 @@ print(c.context) print(cc.numpy) - d = c.copyto(mx.Context('cpu', 0)) print(d.numpy)
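To close the loop, a rough C++ analogue of the mx.op.mul(b, a) call exercised in test_python.py above. This is a sketch only: it assumes the headers are reachable as <mxnet/...> and invokes Entry::body directly, which is presumably what MXFuncInvoke wraps; the Mul helper name is illustrative.

    #include <mxnet/api_registry.h>
    #include <mxnet/narray.h>

    using namespace mxnet;

    NArray Mul(NArray &a, NArray &b) {
      const FunctionRegistry::Entry *f = FunctionRegistry::Find("mul");
      CHECK(f != nullptr) << "mul is not registered";
      NArray out;  // empty target is legal: mul sets kAcceptEmptyMutateTarget
      NArray *use_vars[] = {&a, &b};
      NArray *mutate_vars[] = {&out};
      f->body(use_vars, nullptr, mutate_vars);  // no scalar arguments
      return out;
    }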