Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[R] enable formals and move #225

Merged
merged 1 commit into from
Oct 7, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion R-package/demo/basic_ndarray.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,12 @@ x = as.array(c(1,2,3))
mat = mx.nd.array(x, mx.cpu(0))
mat = mat + 1.0
mat = mat + mat
oldmat = mat
mat = mx.nd.internal.plus.scalar(mat, 1, out=mat)
xx = mat$as.array()

xx = mx.nd.internal.as.array(mat)
# This will result in an error, becase mat has been moved
oldmat + 1

print(xx)

2 changes: 1 addition & 1 deletion R-package/src/Makevars
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,4 @@ mxlib:
cp $(PKGROOT)/lib/libmxnet.so ../inst/libs/libmxnet.so

PKG_CPPFLAGS = -I$(PKGROOT)/include -I$(PKGROOT)/dmlc-core/include
PKG_LIBS = $(LINKMXNET) $(LAPACK_LIBS) $(BLAS_LIBS) $(FLIBS)
PKG_LIBS = $(LAPACK_LIBS) $(BLAS_LIBS) $(FLIBS)
80 changes: 69 additions & 11 deletions R-package/src/ndarray.cc
Original file line number Diff line number Diff line change
Expand Up @@ -218,11 +218,11 @@ NDArray::RObjectType NDArray::Array(
// register normal function.
void NDArray::InitRcppModule() {
using namespace Rcpp; // NOLINT(*)
class_<NDArray>("MXNDArray");
class_<NDArray>("MXNDArray")
.method("as.array", &NDArray::AsNumericVector);
function("mx.nd.load", &NDArray::Load);
function("mx.nd.save", &NDArray::Save);
function("mx.nd.array", &NDArray::Array);
function("mx.nd.internal.as.array", &NDArray::AsRArray);
}

NDArrayFunction::NDArrayFunction(FunctionHandle handle)
Expand Down Expand Up @@ -266,6 +266,7 @@ NDArrayFunction::NDArrayFunction(FunctionHandle handle)
const int kNDArrayArgBeforeScalar = 1;
const int kAcceptEmptyMutateTarget = 1 << 2;
int type_mask;

MX_CALL(MXFuncDescribe(
handle, &num_use_vars_, &num_scalars_,
&num_mutate_vars_, &type_mask));
Expand All @@ -276,32 +277,89 @@ NDArrayFunction::NDArrayFunction(FunctionHandle handle)
begin_scalars_ = num_scalars_;
begin_scalars_ = 0;
}
num_args_ = num_use_vars_ + num_scalars_;
accept_empty_out_ = ((type_mask & kAcceptEmptyMutateTarget) != 0);
begin_mutate_vars_ = num_use_vars_ + num_scalars_;
num_args_ = num_use_vars_ + num_scalars_ + num_mutate_vars_;
accept_empty_out_ = ((type_mask & kAcceptEmptyMutateTarget) != 0) && num_mutate_vars_ == 1;
}

// construct formals
{
Rcpp::List arg_values(num_args_);
std::vector<std::string> arg_names(num_args_);
for (mx_uint i = 0; i < num_use_vars_; ++i) {
std::ostringstream os;
os << "X" << (i + 1);
arg_names[begin_use_vars_ + i] = os.str();
// TODO(KK) this should really be not specified
arg_values[begin_use_vars_ + i] = R_NilValue;
}
for (mx_uint i = 0; i < num_scalars_; ++i) {
std::ostringstream os;
os << "s" << (i + 1);
arg_names[begin_scalars_ + i] = os.str();
// TODO(KK) this should really be not specified
arg_values[begin_scalars_ + i] = R_NilValue;
}
if (accept_empty_out_) {
arg_names[begin_mutate_vars_] = "out";
// this is really optional
arg_values[begin_mutate_vars_] = R_NilValue;
} else {
for (mx_uint i = 0; i < num_mutate_vars_; ++i) {
std::ostringstream os;
os << "out" << (i + 1);
arg_names[begin_mutate_vars_ + i] = os.str();
// TODO(KK) this should really be not specified, not optional
arg_values[begin_mutate_vars_ + i] = R_NilValue;
}
}
formals_ = arg_values;
formals_.attr("names") = arg_names;
}
}

SEXP NDArrayFunction::operator() (SEXP* args) {
BEGIN_RCPP;
RCHECK(accept_empty_out_)
<< "not yet support mutate target";
NDArrayHandle ohandle;
MX_CALL(MXNDArrayCreateNone(&ohandle));

std::vector<mx_float> scalars(num_scalars_);
std::vector<NDArrayHandle> use_vars(num_use_vars_);

for (mx_uint i = 0; i < num_scalars_; ++i) {
// better to use Rcpp cast?
// TODO(KK) better to use Rcpp cast?
scalars[i] = (REAL)(args[begin_scalars_ + i])[0];
}
for (mx_uint i = 0; i < num_use_vars_; ++i) {
use_vars[i] = NDArray::XPtr(args[begin_use_vars_ + i])->handle_;
}

std::vector<NDArrayHandle> mutate_vars(num_mutate_vars_);
Rcpp::List out(num_mutate_vars_);
for (mx_uint i = 0; i < num_mutate_vars_; ++i) {
// TODO(KK) Rcpp way of checking null?
if (args[begin_mutate_vars_ + i] == R_NilValue) {
if (accept_empty_out_) {
NDArrayHandle ohandle;
MX_CALL(MXNDArrayCreateNone(&ohandle));
out[i] = NDArray::RObject(ohandle);
} else {
RLOG_FATAL << "Parameter out need to be specified";
}
} else {
// move the old parameters, these are no longer valid
out[i] = NDArray::Move(args[begin_mutate_vars_ + i]);
}
mutate_vars[i] = NDArray::XPtr(out[i])->handle_;
}

MX_CALL(MXFuncInvoke(handle_,
dmlc::BeginPtr(use_vars),
dmlc::BeginPtr(scalars),
&ohandle));
return NDArray::RObject(ohandle);
dmlc::BeginPtr(mutate_vars)));
if (num_mutate_vars_ == 1) {
return out[0];
} else {
return out;
}
END_RCPP;
}

Expand Down
18 changes: 9 additions & 9 deletions R-package/src/ndarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,14 +77,6 @@ class NDArray {
*/
static RObjectType Array(const Rcpp::RObject& src,
const Context::RObjectType& ctx);
/*!
* \brief Convert the NDArray to R's Array
* \param src the source MX.NDArray
* \return the converted array
*/
inline static Rcpp::NumericVector AsRArray(const RObjectType& src) {
return XPtr(src)->AsNumericVector();
}
/*! \brief static function to initialize the Rcpp functions */
static void InitRcppModule();
/*! \brief destructor */
Expand Down Expand Up @@ -137,6 +129,10 @@ class NDArrayFunction : public ::Rcpp::CppFunction {
return name_.c_str();
}

virtual SEXP get_formals() {
return formals_;
}

virtual DL_FUNC get_function_ptr() {
return (DL_FUNC)NULL; // NOLINT(*)
}
Expand All @@ -159,12 +155,16 @@ class NDArrayFunction : public ::Rcpp::CppFunction {
mx_uint begin_scalars_;
// number of scalars
mx_uint num_scalars_;
// begining of mutate variables
mx_uint begin_mutate_vars_;
// number of mutate variables
mx_uint num_mutate_vars_;
// number of arguments
mx_uint num_args_;
// whether it accept empty output
bool accept_empty_out_;
// ther formals of arguments
Rcpp::List formals_;
};
} // namespace R
} // namespace mxnet
Expand Down Expand Up @@ -206,7 +206,7 @@ inline NDArray* NDArray::XPtr(const Rcpp::RObject& obj) {
NDArray* ptr = Rcpp::as<NDArray*>(obj);
RCHECK(!ptr->moved_)
<< "Passed in a moved NDArray as parameters."
<< " Moved parameters should no longer be used";
<< " Moved parameters should no longer be used\n";
return ptr;
}
} // namespace R
Expand Down