diff --git a/mshadow/extension.h b/mshadow/extension.h index 01e11f1ba9c0..66df5546fcc9 100644 --- a/mshadow/extension.h +++ b/mshadow/extension.h @@ -24,6 +24,7 @@ #include "./extension/concat.h" #include "./extension/implicit_gemm.h" #include "./extension/choose.h" +#include "./extension/fill.h" #include "./extension/one_hot.h" #include "./extension/slice.h" #include "./extension/take.h" diff --git a/mshadow/extension/fill.h b/mshadow/extension/fill.h new file mode 100644 index 000000000000..4ac62c1673e5 --- /dev/null +++ b/mshadow/extension/fill.h @@ -0,0 +1,103 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file fill.h + * \brief support for implicit array filling operation + * \author Xingjian Shi + */ +#ifndef MSHADOW_EXTENSION_FILL_H_ +#define MSHADOW_EXTENSION_FILL_H_ + +#include "../extension.h" + + +namespace mshadow { +namespace expr { +/*! + * \brief Set value of a specific element in each line of the data matrix. + * \tparam SrcExp type of src expression + * \tparam ValExp type of val expression + * \tparam IndexExp type of index expression + * \tparam DType the type of ret expression + */ +template +struct MatFillRowElementExp: + public Exp, + DType, type::kChainer> { + /*! \brief src operand */ + const SrcExp &src_; + const ValExp &val_; + /*! \brief index operand */ + const IndexExp &index_; + /*! \brief constructor */ + MatFillRowElementExp(const SrcExp &src, const ValExp &val, const IndexExp &index) + : src_(src), val_(val), index_(index) {} +}; + +template +inline MatFillRowElementExp +mat_fill_row_element(const Exp &src, + const Exp &val, + const Exp &index) { + TypeCheckPass::kDim == 2 && ExpInfo::kDim == 1 + && ExpInfo::kDim == 1>::Error_Expression_Does_Not_Meet_Dimension_Req(); + return MatFillRowElementExp(src.self(), + val.self(), index.self()); +} + +//---------------------- +// Execution plan +//---------------------- +template +struct Plan, DType> { + public: + explicit Plan(const MatFillRowElementExp &e) + : src_(MakePlan(e.src_)), + val_(MakePlan(e.val_)), + index_(MakePlan(e.index_)) { + } + MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { + index_t idx = static_cast(index_.Eval(0, y)); + if (idx == x) { + return static_cast(val_.Eval(0, y)); + } else { + return static_cast(src_.Eval(y, x)); + } + } + + private: + expr::Plan src_; + expr::Plan val_; + expr::Plan index_; +}; + +template +inline Plan, DType> +MakePlan(const MatFillRowElementExp &exp) { + return Plan, DType>(exp); +} + +template +struct ShapeCheck > { + inline static Shape + Check(const MatFillRowElementExp &t) { + CHECK(dim == 2) + << "MatFillRowElementExp only support 2 dimension output"; + Shape<2> shape_src = ShapeCheck<2, SrcExp>::Check(t.src_); + Shape<1> shape_val = ShapeCheck<1, ValExp>::Check(t.val_); + Shape<1> shape_index = ShapeCheck<1, IndexExp>::Check(t.index_); + CHECK((shape_src[0] == shape_index[0]) && (shape_index[0] == shape_val[0])) + << "mat_fill_row_element index length, val length and number of rows in matrix"; + return shape_src; + } +}; + +template +struct ExpInfo > { + static const int kDim = 2; + static const int kDevMask = + ExpInfo::kDevMask & ExpInfo::kDevMask & ExpInfo::kDevMask; +}; +} // namespace expr +} // namespace mshadow +#endif // MSHADOW_EXTENSION_FILL_H_