Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[R] allow operator overloading #221

Merged
merged 1 commit into from
Oct 7, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions R-package/R/base.R
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
require(methods)

loadModule("mxnet", TRUE)
setOldClass("mx.NDArray")
setMethod("+", signature(e1="mx.NDArray", e2="numeric"), function(e1, e2) {
mx.nd.internal.plus.scalar(e1, e2)
})
.onLoad <- function(libname, pkgname) {
loadModule("mxnet", TRUE)
init.ndarray.methods()
}
11 changes: 9 additions & 2 deletions R-package/R/ndarray.R
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
#' NDArray
#'
#' Additional NDArray related operations
require(methods)

init.ndarray.methods <-function() {
require(methods)
setMethod("+", signature(e1="Rcpp_MXNDArray", e2="numeric"), function(e1, e2) {
mx.nd.internal.plus.scalar(e1, e2)
})
setMethod("+", signature(e1="Rcpp_MXNDArray", e2="Rcpp_MXNDArray"), function(e1, e2) {
mx.nd.internal.plus(e1, e2)
})
}
9 changes: 6 additions & 3 deletions R-package/demo/basic_ndarray.R
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
require(mxnet)
require(methods)
x = as.array(c(1,2,3))


x = as.array(c(1,2,3))
mat = mx.nd.array(x, mx.cpu(0))
mat = mx.nd.internal.plus(mat, mat)
mat = mat + 1.0
mat = mat + mat

xx = mx.nd.internal.as.array(mat)
print(class(mat))

print(xx)

1 change: 1 addition & 0 deletions R-package/src/ndarray.cc
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,7 @@ NDArray::RObjectType NDArray::Array(
// register normal function.
void NDArray::InitRcppModule() {
using namespace Rcpp; // NOLINT(*)
class_<NDArray>("MXNDArray");
function("mx.nd.load", &NDArray::Load);
function("mx.nd.save", &NDArray::Save);
function("mx.nd.array", &NDArray::Array);
Expand Down
23 changes: 12 additions & 11 deletions R-package/src/ndarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class NDArrayFunction;
class NDArray {
public:
/*! \brief The type of NDArray in R's side */
typedef Rcpp::List RObjectType;
typedef Rcpp::RObject RObjectType;
/*! \return convert the NDArray to R's Array */
Rcpp::NumericVector AsNumericVector() const;
/*! \return The shape of the array */
Expand All @@ -49,7 +49,7 @@ class NDArray {
* \param obj The R NDArray object
* \return The external pointer to the object
*/
inline static Rcpp::XPtr<NDArray> XPtr(const Rcpp::RObject& obj);
inline static NDArray* XPtr(const Rcpp::RObject& obj);
/*!
* \brief Load a list of ndarray from the file.
* \param filename the name of the file.
Expand Down Expand Up @@ -170,6 +170,8 @@ class NDArrayFunction : public ::Rcpp::CppFunction {
} // namespace mxnet


RCPP_EXPOSED_CLASS_NODECL(::mxnet::R::NDArray);

namespace mxnet {
namespace R {
// implementatins of inline functions
Expand All @@ -187,22 +189,21 @@ inline Rcpp::Dimension NDArray::shape() const {
inline NDArray::RObjectType NDArray::RObject(
NDArrayHandle handle,
bool writable) {
Rcpp::List ret = Rcpp::List::create(
Rcpp::Named("ptr") = Rcpp::XPtr<NDArray>(new NDArray(handle, writable)));
ret.attr("class") = "mx.NDArray";
return ret;
NDArray* p = new NDArray(handle, writable);
// TODO(KK) can we avoid use internal::make_new_object?
// The Wrap function requires NDArray object instead of ptr,
// Which will trigger destructor
return Rcpp::internal::make_new_object(p);
}

inline NDArray::RObjectType NDArray::Move(const Rcpp::RObject& src) {
Rcpp::XPtr<NDArray> old = NDArray::XPtr(src);
NDArray* old = NDArray::XPtr(src);
old->moved_ = true;
return NDArray::RObject(old->handle_, old->writable_);
}

inline Rcpp::XPtr<NDArray> NDArray::XPtr(const Rcpp::RObject& obj) {
Rcpp::List ret(obj);
Rcpp::RObject xptr = ret[0];
Rcpp::XPtr<NDArray> ptr(xptr);
inline NDArray* NDArray::XPtr(const Rcpp::RObject& obj) {
NDArray* ptr = Rcpp::as<NDArray*>(obj);
RCHECK(!ptr->moved_)
<< "Passed in a moved NDArray as parameters."
<< " Moved parameters should no longer be used";
Expand Down