From 5d9da0f19bc74958b17f76e88b4465b193a5811f Mon Sep 17 00:00:00 2001 From: tqchen Date: Tue, 6 Oct 2015 23:08:25 -0700 Subject: [PATCH] [R] allow operator overloading --- R-package/R/base.R | 9 ++++----- R-package/R/ndarray.R | 11 +++++++++-- R-package/demo/basic_ndarray.R | 9 ++++++--- R-package/src/ndarray.cc | 1 + R-package/src/ndarray.h | 23 ++++++++++++----------- 5 files changed, 32 insertions(+), 21 deletions(-) diff --git a/R-package/R/base.R b/R-package/R/base.R index 5fb8f501c82d..6b20f63530cc 100644 --- a/R-package/R/base.R +++ b/R-package/R/base.R @@ -1,7 +1,6 @@ require(methods) -loadModule("mxnet", TRUE) -setOldClass("mx.NDArray") -setMethod("+", signature(e1="mx.NDArray", e2="numeric"), function(e1, e2) { - mx.nd.internal.plus.scalar(e1, e2) -}) +.onLoad <- function(libname, pkgname) { + loadModule("mxnet", TRUE) + init.ndarray.methods() +} diff --git a/R-package/R/ndarray.R b/R-package/R/ndarray.R index 4bd00b817e0c..b6c744030b91 100644 --- a/R-package/R/ndarray.R +++ b/R-package/R/ndarray.R @@ -1,5 +1,12 @@ #' NDArray #' #' Additional NDArray related operations -require(methods) - +init.ndarray.methods <-function() { + require(methods) + setMethod("+", signature(e1="Rcpp_MXNDArray", e2="numeric"), function(e1, e2) { + mx.nd.internal.plus.scalar(e1, e2) + }) + setMethod("+", signature(e1="Rcpp_MXNDArray", e2="Rcpp_MXNDArray"), function(e1, e2) { + mx.nd.internal.plus(e1, e2) + }) +} diff --git a/R-package/demo/basic_ndarray.R b/R-package/demo/basic_ndarray.R index a22e4fc73a41..21ea085a36e5 100644 --- a/R-package/demo/basic_ndarray.R +++ b/R-package/demo/basic_ndarray.R @@ -1,10 +1,13 @@ require(mxnet) require(methods) -x = as.array(c(1,2,3)) + +x = as.array(c(1,2,3)) mat = mx.nd.array(x, mx.cpu(0)) -mat = mx.nd.internal.plus(mat, mat) +mat = mat + 1.0 +mat = mat + mat + xx = mx.nd.internal.as.array(mat) -print(class(mat)) + print(xx) diff --git a/R-package/src/ndarray.cc b/R-package/src/ndarray.cc index c13fd70657a8..f65de71f5be8 100644 --- a/R-package/src/ndarray.cc +++ b/R-package/src/ndarray.cc @@ -219,6 +219,7 @@ NDArray::RObjectType NDArray::Array( // register normal function. void NDArray::InitRcppModule() { using namespace Rcpp; // NOLINT(*) + class_("MXNDArray"); function("mx.nd.load", &NDArray::Load); function("mx.nd.save", &NDArray::Save); function("mx.nd.array", &NDArray::Array); diff --git a/R-package/src/ndarray.h b/R-package/src/ndarray.h index 5574fe2f5bd4..06f995339dd9 100644 --- a/R-package/src/ndarray.h +++ b/R-package/src/ndarray.h @@ -22,7 +22,7 @@ class NDArrayFunction; class NDArray { public: /*! \brief The type of NDArray in R's side */ - typedef Rcpp::List RObjectType; + typedef Rcpp::RObject RObjectType; /*! \return convert the NDArray to R's Array */ Rcpp::NumericVector AsNumericVector() const; /*! \return The shape of the array */ @@ -49,7 +49,7 @@ class NDArray { * \param obj The R NDArray object * \return The external pointer to the object */ - inline static Rcpp::XPtr XPtr(const Rcpp::RObject& obj); + inline static NDArray* XPtr(const Rcpp::RObject& obj); /*! * \brief Load a list of ndarray from the file. * \param filename the name of the file. @@ -170,6 +170,8 @@ class NDArrayFunction : public ::Rcpp::CppFunction { } // namespace mxnet +RCPP_EXPOSED_CLASS_NODECL(::mxnet::R::NDArray); + namespace mxnet { namespace R { // implementatins of inline functions @@ -187,22 +189,21 @@ inline Rcpp::Dimension NDArray::shape() const { inline NDArray::RObjectType NDArray::RObject( NDArrayHandle handle, bool writable) { - Rcpp::List ret = Rcpp::List::create( - Rcpp::Named("ptr") = Rcpp::XPtr(new NDArray(handle, writable))); - ret.attr("class") = "mx.NDArray"; - return ret; + NDArray* p = new NDArray(handle, writable); + // TODO(KK) can we avoid use internal::make_new_object? + // The Wrap function requires NDArray object instead of ptr, + // Which will trigger destructor + return Rcpp::internal::make_new_object(p); } inline NDArray::RObjectType NDArray::Move(const Rcpp::RObject& src) { - Rcpp::XPtr old = NDArray::XPtr(src); + NDArray* old = NDArray::XPtr(src); old->moved_ = true; return NDArray::RObject(old->handle_, old->writable_); } -inline Rcpp::XPtr NDArray::XPtr(const Rcpp::RObject& obj) { - Rcpp::List ret(obj); - Rcpp::RObject xptr = ret[0]; - Rcpp::XPtr ptr(xptr); +inline NDArray* NDArray::XPtr(const Rcpp::RObject& obj) { + NDArray* ptr = Rcpp::as(obj); RCHECK(!ptr->moved_) << "Passed in a moved NDArray as parameters." << " Moved parameters should no longer be used";