Skip to content

Commit

Permalink
Merge pull request apache#92 from sxjscience/pr
Browse files Browse the repository at this point in the history
Add new operator `MatFillRowElement`
  • Loading branch information
winstywang committed Jan 7, 2016
2 parents 01ce2c5 + fa897ba commit 29863c5
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 0 deletions.
1 change: 1 addition & 0 deletions mshadow/extension.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "./extension/concat.h"
#include "./extension/implicit_gemm.h"
#include "./extension/choose.h"
#include "./extension/fill.h"
#include "./extension/one_hot.h"
#include "./extension/slice.h"
#include "./extension/take.h"
Expand Down
103 changes: 103 additions & 0 deletions mshadow/extension/fill.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
/*!
* Copyright (c) 2015 by Contributors
* \file fill.h
* \brief support for implicit array filling operation
* \author Xingjian Shi
*/
#ifndef MSHADOW_EXTENSION_FILL_H_
#define MSHADOW_EXTENSION_FILL_H_

#include "../extension.h"


namespace mshadow {
namespace expr {
/*!
* \brief Set value of a specific element in each line of the data matrix.
* \tparam SrcExp type of src expression
* \tparam ValExp type of val expression
* \tparam IndexExp type of index expression
* \tparam DType the type of ret expression
*/
template<typename SrcExp, typename ValExp, typename IndexExp, typename DType>
struct MatFillRowElementExp:
public Exp<MatFillRowElementExp<SrcExp, ValExp, IndexExp, DType>,
DType, type::kChainer> {
/*! \brief src operand */
const SrcExp &src_;
const ValExp &val_;
/*! \brief index operand */
const IndexExp &index_;
/*! \brief constructor */
MatFillRowElementExp(const SrcExp &src, const ValExp &val, const IndexExp &index)
: src_(src), val_(val), index_(index) {}
};

template<typename SrcExp, typename ValExp, typename IndexExp,
typename SDType, typename VDType, typename IDType, int e1, int e2, int e3>
inline MatFillRowElementExp<SrcExp, ValExp, IndexExp, SDType>
mat_fill_row_element(const Exp<SrcExp, SDType, e1> &src,
const Exp<ValExp, VDType, e2> &val,
const Exp<IndexExp, IDType, e3> &index) {
TypeCheckPass<ExpInfo<SrcExp>::kDim == 2 && ExpInfo<ValExp>::kDim == 1
&& ExpInfo<IndexExp>::kDim == 1>::Error_Expression_Does_Not_Meet_Dimension_Req();
return MatFillRowElementExp<SrcExp, ValExp, IndexExp, SDType>(src.self(),
val.self(), index.self());
}

//----------------------
// Execution plan
//----------------------
template<typename SrcExp, typename ValExp, typename IndexExp, typename DType>
struct Plan<MatFillRowElementExp<SrcExp, ValExp, IndexExp, DType>, DType> {
public:
explicit Plan(const MatFillRowElementExp<SrcExp, ValExp, IndexExp, DType> &e)
: src_(MakePlan(e.src_)),
val_(MakePlan(e.val_)),
index_(MakePlan(e.index_)) {
}
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const {
index_t idx = static_cast<index_t>(index_.Eval(0, y));
if (idx == x) {
return static_cast<DType>(val_.Eval(0, y));
} else {
return static_cast<DType>(src_.Eval(y, x));
}
}

private:
expr::Plan<SrcExp, DType> src_;
expr::Plan<ValExp, DType> val_;
expr::Plan<IndexExp, DType> index_;
};

template<typename SrcExp, typename ValExp, typename IndexExp, typename DType>
inline Plan<MatFillRowElementExp<SrcExp, ValExp, IndexExp, DType>, DType>
MakePlan(const MatFillRowElementExp<SrcExp, ValExp, IndexExp, DType> &exp) {
return Plan<MatFillRowElementExp<SrcExp, ValExp, IndexExp, DType>, DType>(exp);
}

template<int dim, typename SrcExp, typename ValExp, typename IndexExp, typename DType>
struct ShapeCheck<dim, MatFillRowElementExp<SrcExp, ValExp, IndexExp, DType> > {
inline static Shape<dim>
Check(const MatFillRowElementExp<SrcExp, ValExp, IndexExp, DType> &t) {
CHECK(dim == 2)
<< "MatFillRowElementExp only support 2 dimension output";
Shape<2> shape_src = ShapeCheck<2, SrcExp>::Check(t.src_);
Shape<1> shape_val = ShapeCheck<1, ValExp>::Check(t.val_);
Shape<1> shape_index = ShapeCheck<1, IndexExp>::Check(t.index_);
CHECK((shape_src[0] == shape_index[0]) && (shape_index[0] == shape_val[0]))
<< "mat_fill_row_element index length, val length and number of rows in matrix";
return shape_src;
}
};

template<typename SrcExp, typename ValExp, typename IndexExp, typename DType>
struct ExpInfo<MatFillRowElementExp<SrcExp, ValExp, IndexExp, DType> > {
static const int kDim = 2;
static const int kDevMask =
ExpInfo<SrcExp>::kDevMask & ExpInfo<ValExp>::kDevMask & ExpInfo<IndexExp>::kDevMask;
};
} // namespace expr
} // namespace mshadow
#endif // MSHADOW_EXTENSION_FILL_H_

0 comments on commit 29863c5

Please sign in to comment.