diff --git a/R-package/src/base.h b/R-package/src/base.h index f8eacee20fe4..eb555193ec10 100644 --- a/R-package/src/base.h +++ b/R-package/src/base.h @@ -7,15 +7,20 @@ #define MXNET_RCPP_BASE_H_ #include +#include #include -// to be removed -#include namespace mxnet { namespace R { // change to Rcpp::cerr later, for compatiblity of older version for now -#define RLOG_FATAL LOG(FATAL) +#define RLOG_FATAL ::Rcpp::Rcerr + +// checking macro for R side +#define RCHECK(x) \ + if (!(x)) \ + RLOG_FATAL << "Check " \ + "failed: " #x << ' ' /*! * \brief protected MXNet C API call, report R error if happens. diff --git a/R-package/src/ndarray.cc b/R-package/src/ndarray.cc index 0d1edcbda035..8cbfff3bfa6e 100644 --- a/R-package/src/ndarray.cc +++ b/R-package/src/ndarray.cc @@ -1,5 +1,4 @@ #include -#include #include "./base.h" #include "./ndarray.h" @@ -41,7 +40,7 @@ SEXP NDArray::Load(const std::string& filename) { &out_name_size, &out_names)); Rcpp::List out(out_size); for (mx_uint i = 0; i < out_size; ++i) { - out[i] = Rcpp::XPtr(new NDArray(out_arr[i])); + out[i] = NDArray::RObject(out_arr[i]); } if (out_name_size != 0) { std::vector lst_names(out_size); @@ -105,9 +104,8 @@ NDArrayFunction::NDArrayFunction(FunctionHandle handle) SEXP NDArrayFunction::operator() (SEXP* args) { BEGIN_RCPP; - if (!accept_empty_out_) { - RLOG_FATAL << "not yet support mutate target"; - } + RCHECK(accept_empty_out_) + << "not yet support mutate target"; NDArrayHandle ohandle; MX_CALL(MXNDArrayCreateNone(&ohandle)); std::vector scalars(num_scalars_); @@ -124,7 +122,7 @@ SEXP NDArrayFunction::operator() (SEXP* args) { dmlc::BeginPtr(use_vars), dmlc::BeginPtr(scalars), &ohandle)); - return Rcpp::XPtr(new NDArray(ohandle)); + return NDArray::RObject(ohandle); END_RCPP; } @@ -137,9 +135,8 @@ void NDArray::InitRcppModule() { void NDArrayFunction::InitRcppModule() { Rcpp::Module* scope = ::getCurrentScope(); - if (scope == NULL) { - RLOG_FATAL << "Init Module need to be called inside scope"; - } + RCHECK(scope != NULL) + << "Init Module need to be called inside scope"; mx_uint out_size; FunctionHandle *arr; MX_CALL(MXListFunctions(&out_size, &arr)); diff --git a/R-package/src/ndarray.h b/R-package/src/ndarray.h index 174d0568a316..269d2bdb5e4d 100644 --- a/R-package/src/ndarray.h +++ b/R-package/src/ndarray.h @@ -20,13 +20,17 @@ class NDArray { /*! \brief default constructor */ NDArray() {} /*! - * \brief construct NDArray from handle + * \brief create a R object that correspond to the NDArray * \param handle the NDArrayHandle needed for output. * \param writable Whether the NDArray is writable or not. */ - explicit NDArray(NDArrayHandle handle, - bool writable = true) - : handle_(handle), writable_(writable) {} + static SEXP RObject(NDArrayHandle handle, bool writable = true) { + NDArray *nd = new NDArray(); + nd->handle_ = handle; + nd->writable_ = writable; + // will call destructor after finalize + return Rcpp::XPtr(nd, true); + } /*! * \brief Load a list of ndarray from the file. * \param filename the name of the file. @@ -42,6 +46,12 @@ class NDArray { /*! \brief static function to initialize the Rcpp functions */ static void InitRcppModule(); + /*! \brief destructor */ + ~NDArray() { + // free the handle + MX_CALL(MXNDArrayFree(handle_)); + } + private: // declare friend class friend class NDArrayFunction;