From 2f4fb6d6380aa3c577f76aa800964596609fd162 Mon Sep 17 00:00:00 2001 From: tqchen Date: Wed, 7 Oct 2015 12:06:23 -0700 Subject: [PATCH] [R] enable formals and move --- R-package/demo/basic_ndarray.R | 6 ++- R-package/src/Makevars | 2 +- R-package/src/ndarray.cc | 80 +++++++++++++++++++++++++++++----- R-package/src/ndarray.h | 18 ++++---- 4 files changed, 84 insertions(+), 22 deletions(-) diff --git a/R-package/demo/basic_ndarray.R b/R-package/demo/basic_ndarray.R index 21ea085a36e5..12ad27d543fd 100644 --- a/R-package/demo/basic_ndarray.R +++ b/R-package/demo/basic_ndarray.R @@ -6,8 +6,12 @@ x = as.array(c(1,2,3)) mat = mx.nd.array(x, mx.cpu(0)) mat = mat + 1.0 mat = mat + mat +oldmat = mat +mat = mx.nd.internal.plus.scalar(mat, 1, out=mat) +xx = mat$as.array() -xx = mx.nd.internal.as.array(mat) +# This will result in an error, becase mat has been moved +oldmat + 1 print(xx) diff --git a/R-package/src/Makevars b/R-package/src/Makevars index e4c54ace37ed..1dede45a1e9f 100644 --- a/R-package/src/Makevars +++ b/R-package/src/Makevars @@ -16,4 +16,4 @@ mxlib: cp $(PKGROOT)/lib/libmxnet.so ../inst/libs/libmxnet.so PKG_CPPFLAGS = -I$(PKGROOT)/include -I$(PKGROOT)/dmlc-core/include -PKG_LIBS = $(LINKMXNET) $(LAPACK_LIBS) $(BLAS_LIBS) $(FLIBS) +PKG_LIBS = $(LAPACK_LIBS) $(BLAS_LIBS) $(FLIBS) diff --git a/R-package/src/ndarray.cc b/R-package/src/ndarray.cc index 83a4039515b1..b894937de231 100644 --- a/R-package/src/ndarray.cc +++ b/R-package/src/ndarray.cc @@ -218,11 +218,11 @@ NDArray::RObjectType NDArray::Array( // register normal function. void NDArray::InitRcppModule() { using namespace Rcpp; // NOLINT(*) - class_("MXNDArray"); + class_("MXNDArray") + .method("as.array", &NDArray::AsNumericVector); function("mx.nd.load", &NDArray::Load); function("mx.nd.save", &NDArray::Save); function("mx.nd.array", &NDArray::Array); - function("mx.nd.internal.as.array", &NDArray::AsRArray); } NDArrayFunction::NDArrayFunction(FunctionHandle handle) @@ -266,6 +266,7 @@ NDArrayFunction::NDArrayFunction(FunctionHandle handle) const int kNDArrayArgBeforeScalar = 1; const int kAcceptEmptyMutateTarget = 1 << 2; int type_mask; + MX_CALL(MXFuncDescribe( handle, &num_use_vars_, &num_scalars_, &num_mutate_vars_, &type_mask)); @@ -276,32 +277,89 @@ NDArrayFunction::NDArrayFunction(FunctionHandle handle) begin_scalars_ = num_scalars_; begin_scalars_ = 0; } - num_args_ = num_use_vars_ + num_scalars_; - accept_empty_out_ = ((type_mask & kAcceptEmptyMutateTarget) != 0); + begin_mutate_vars_ = num_use_vars_ + num_scalars_; + num_args_ = num_use_vars_ + num_scalars_ + num_mutate_vars_; + accept_empty_out_ = ((type_mask & kAcceptEmptyMutateTarget) != 0) && num_mutate_vars_ == 1; + } + + // construct formals + { + Rcpp::List arg_values(num_args_); + std::vector arg_names(num_args_); + for (mx_uint i = 0; i < num_use_vars_; ++i) { + std::ostringstream os; + os << "X" << (i + 1); + arg_names[begin_use_vars_ + i] = os.str(); + // TODO(KK) this should really be not specified + arg_values[begin_use_vars_ + i] = R_NilValue; + } + for (mx_uint i = 0; i < num_scalars_; ++i) { + std::ostringstream os; + os << "s" << (i + 1); + arg_names[begin_scalars_ + i] = os.str(); + // TODO(KK) this should really be not specified + arg_values[begin_scalars_ + i] = R_NilValue; + } + if (accept_empty_out_) { + arg_names[begin_mutate_vars_] = "out"; + // this is really optional + arg_values[begin_mutate_vars_] = R_NilValue; + } else { + for (mx_uint i = 0; i < num_mutate_vars_; ++i) { + std::ostringstream os; + os << "out" << (i + 1); + arg_names[begin_mutate_vars_ + i] = os.str(); + // TODO(KK) this should really be not specified, not optional + arg_values[begin_mutate_vars_ + i] = R_NilValue; + } + } + formals_ = arg_values; + formals_.attr("names") = arg_names; } } SEXP NDArrayFunction::operator() (SEXP* args) { BEGIN_RCPP; - RCHECK(accept_empty_out_) - << "not yet support mutate target"; - NDArrayHandle ohandle; - MX_CALL(MXNDArrayCreateNone(&ohandle)); + std::vector scalars(num_scalars_); std::vector use_vars(num_use_vars_); for (mx_uint i = 0; i < num_scalars_; ++i) { - // better to use Rcpp cast? + // TODO(KK) better to use Rcpp cast? scalars[i] = (REAL)(args[begin_scalars_ + i])[0]; } for (mx_uint i = 0; i < num_use_vars_; ++i) { use_vars[i] = NDArray::XPtr(args[begin_use_vars_ + i])->handle_; } + + std::vector mutate_vars(num_mutate_vars_); + Rcpp::List out(num_mutate_vars_); + for (mx_uint i = 0; i < num_mutate_vars_; ++i) { + // TODO(KK) Rcpp way of checking null? + if (args[begin_mutate_vars_ + i] == R_NilValue) { + if (accept_empty_out_) { + NDArrayHandle ohandle; + MX_CALL(MXNDArrayCreateNone(&ohandle)); + out[i] = NDArray::RObject(ohandle); + } else { + RLOG_FATAL << "Parameter out need to be specified"; + } + } else { + // move the old parameters, these are no longer valid + out[i] = NDArray::Move(args[begin_mutate_vars_ + i]); + } + mutate_vars[i] = NDArray::XPtr(out[i])->handle_; + } + MX_CALL(MXFuncInvoke(handle_, dmlc::BeginPtr(use_vars), dmlc::BeginPtr(scalars), - &ohandle)); - return NDArray::RObject(ohandle); + dmlc::BeginPtr(mutate_vars))); + if (num_mutate_vars_ == 1) { + return out[0]; + } else { + return out; + } END_RCPP; } diff --git a/R-package/src/ndarray.h b/R-package/src/ndarray.h index 06f995339dd9..21371f8ba57b 100644 --- a/R-package/src/ndarray.h +++ b/R-package/src/ndarray.h @@ -77,14 +77,6 @@ class NDArray { */ static RObjectType Array(const Rcpp::RObject& src, const Context::RObjectType& ctx); - /*! - * \brief Convert the NDArray to R's Array - * \param src the source MX.NDArray - * \return the converted array - */ - inline static Rcpp::NumericVector AsRArray(const RObjectType& src) { - return XPtr(src)->AsNumericVector(); - } /*! \brief static function to initialize the Rcpp functions */ static void InitRcppModule(); /*! \brief destructor */ @@ -137,6 +129,10 @@ class NDArrayFunction : public ::Rcpp::CppFunction { return name_.c_str(); } + virtual SEXP get_formals() { + return formals_; + } + virtual DL_FUNC get_function_ptr() { return (DL_FUNC)NULL; // NOLINT(*) } @@ -159,12 +155,16 @@ class NDArrayFunction : public ::Rcpp::CppFunction { mx_uint begin_scalars_; // number of scalars mx_uint num_scalars_; + // begining of mutate variables + mx_uint begin_mutate_vars_; // number of mutate variables mx_uint num_mutate_vars_; // number of arguments mx_uint num_args_; // whether it accept empty output bool accept_empty_out_; + // ther formals of arguments + Rcpp::List formals_; }; } // namespace R } // namespace mxnet @@ -206,7 +206,7 @@ inline NDArray* NDArray::XPtr(const Rcpp::RObject& obj) { NDArray* ptr = Rcpp::as(obj); RCHECK(!ptr->moved_) << "Passed in a moved NDArray as parameters." - << " Moved parameters should no longer be used"; + << " Moved parameters should no longer be used\n"; return ptr; } } // namespace R