Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Merge pull request #221 from tqchen/R
Browse files Browse the repository at this point in the history
[R] allow operator overloading
  • Loading branch information
tqchen committed Oct 7, 2015
2 parents be460fb + 5d9da0f commit 5ce672d
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 21 deletions.
9 changes: 4 additions & 5 deletions R-package/R/base.R
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
require(methods)

loadModule("mxnet", TRUE)
setOldClass("mx.NDArray")
setMethod("+", signature(e1="mx.NDArray", e2="numeric"), function(e1, e2) {
mx.nd.internal.plus.scalar(e1, e2)
})
.onLoad <- function(libname, pkgname) {
loadModule("mxnet", TRUE)
init.ndarray.methods()
}
11 changes: 9 additions & 2 deletions R-package/R/ndarray.R
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
#' NDArray
#'
#' Additional NDArray related operations
require(methods)

init.ndarray.methods <-function() {
require(methods)
setMethod("+", signature(e1="Rcpp_MXNDArray", e2="numeric"), function(e1, e2) {
mx.nd.internal.plus.scalar(e1, e2)
})
setMethod("+", signature(e1="Rcpp_MXNDArray", e2="Rcpp_MXNDArray"), function(e1, e2) {
mx.nd.internal.plus(e1, e2)
})
}
9 changes: 6 additions & 3 deletions R-package/demo/basic_ndarray.R
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
require(mxnet)
require(methods)
x = as.array(c(1,2,3))


x = as.array(c(1,2,3))
mat = mx.nd.array(x, mx.cpu(0))
mat = mx.nd.internal.plus(mat, mat)
mat = mat + 1.0
mat = mat + mat

xx = mx.nd.internal.as.array(mat)
print(class(mat))

print(xx)

1 change: 1 addition & 0 deletions R-package/src/ndarray.cc
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,7 @@ NDArray::RObjectType NDArray::Array(
// register normal function.
void NDArray::InitRcppModule() {
using namespace Rcpp; // NOLINT(*)
class_<NDArray>("MXNDArray");
function("mx.nd.load", &NDArray::Load);
function("mx.nd.save", &NDArray::Save);
function("mx.nd.array", &NDArray::Array);
Expand Down
23 changes: 12 additions & 11 deletions R-package/src/ndarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class NDArrayFunction;
class NDArray {
public:
/*! \brief The type of NDArray in R's side */
typedef Rcpp::List RObjectType;
typedef Rcpp::RObject RObjectType;
/*! \return convert the NDArray to R's Array */
Rcpp::NumericVector AsNumericVector() const;
/*! \return The shape of the array */
Expand All @@ -49,7 +49,7 @@ class NDArray {
* \param obj The R NDArray object
* \return The external pointer to the object
*/
inline static Rcpp::XPtr<NDArray> XPtr(const Rcpp::RObject& obj);
inline static NDArray* XPtr(const Rcpp::RObject& obj);
/*!
* \brief Load a list of ndarray from the file.
* \param filename the name of the file.
Expand Down Expand Up @@ -170,6 +170,8 @@ class NDArrayFunction : public ::Rcpp::CppFunction {
} // namespace mxnet


RCPP_EXPOSED_CLASS_NODECL(::mxnet::R::NDArray);

namespace mxnet {
namespace R {
// implementatins of inline functions
Expand All @@ -187,22 +189,21 @@ inline Rcpp::Dimension NDArray::shape() const {
inline NDArray::RObjectType NDArray::RObject(
NDArrayHandle handle,
bool writable) {
Rcpp::List ret = Rcpp::List::create(
Rcpp::Named("ptr") = Rcpp::XPtr<NDArray>(new NDArray(handle, writable)));
ret.attr("class") = "mx.NDArray";
return ret;
NDArray* p = new NDArray(handle, writable);
// TODO(KK) can we avoid use internal::make_new_object?
// The Wrap function requires NDArray object instead of ptr,
// Which will trigger destructor
return Rcpp::internal::make_new_object(p);
}

inline NDArray::RObjectType NDArray::Move(const Rcpp::RObject& src) {
Rcpp::XPtr<NDArray> old = NDArray::XPtr(src);
NDArray* old = NDArray::XPtr(src);
old->moved_ = true;
return NDArray::RObject(old->handle_, old->writable_);
}

inline Rcpp::XPtr<NDArray> NDArray::XPtr(const Rcpp::RObject& obj) {
Rcpp::List ret(obj);
Rcpp::RObject xptr = ret[0];
Rcpp::XPtr<NDArray> ptr(xptr);
inline NDArray* NDArray::XPtr(const Rcpp::RObject& obj) {
NDArray* ptr = Rcpp::as<NDArray*>(obj);
RCHECK(!ptr->moved_)
<< "Passed in a moved NDArray as parameters."
<< " Moved parameters should no longer be used";
Expand Down

0 comments on commit 5ce672d

Please sign in to comment.