From a92afc5a2d2667937507e94fdf0f2646cb2388bb Mon Sep 17 00:00:00 2001 From: sxjscience Date: Wed, 6 Jan 2016 11:14:59 +0800 Subject: [PATCH 1/2] Add new operator `MatFillRowElement`, which sets the value of a specific element in each line of the data matrix. --- mshadow/extension.h | 1 + mshadow/extension/fill.h | 102 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 103 insertions(+) create mode 100644 mshadow/extension/fill.h diff --git a/mshadow/extension.h b/mshadow/extension.h index 01e11f1ba9c0..66df5546fcc9 100644 --- a/mshadow/extension.h +++ b/mshadow/extension.h @@ -24,6 +24,7 @@ #include "./extension/concat.h" #include "./extension/implicit_gemm.h" #include "./extension/choose.h" +#include "./extension/fill.h" #include "./extension/one_hot.h" #include "./extension/slice.h" #include "./extension/take.h" diff --git a/mshadow/extension/fill.h b/mshadow/extension/fill.h new file mode 100644 index 000000000000..a7b60d258210 --- /dev/null +++ b/mshadow/extension/fill.h @@ -0,0 +1,102 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file fill.h + * \brief support for implicit array filling operation + * \author Xingjian Shi + */ +#ifndef MSHADOW_EXTENSION_FILL_H_ +#define MSHADOW_EXTENSION_FILL_H_ + +#include "../extension.h" + + +namespace mshadow { +namespace expr { +/*! + * \brief Set value of a specific element in each line of the data matrix. + * \tparam SrcExp type of src expression + * \tparam ValExp type of val expression + * \tparam IndexExp type of index expression + * \tparam DType the type of ret expression + */ +template +struct MatFillRowElementExp: + public Exp, + DType, type::kChainer> { + /*! \brief src operand */ + const SrcExp &src_; + const ValExp &val_; + /*! \brief index operand */ + const IndexExp &index_; + /*! \brief constructor */ + MatFillRowElementExp(const SrcExp &src, const ValExp &val, const IndexExp &index) + : src_(src), val_(val), index_(index) {} +}; + +template +inline MatFillRowElementExp +mat_fill_row_element(const Exp &src, + const Exp &val, + const Exp &index) { + TypeCheckPass::kDim == 2 && ExpInfo::kDim == 1 && ExpInfo::kDim == 1> + ::Error_Expression_Does_Not_Meet_Dimension_Req(); + return MatFillRowElementExp(src.self(), val.self(), index.self()); +} + +//---------------------- +// Execution plan +//---------------------- +template +struct Plan, DType> { + public: + explicit Plan(const MatFillRowElementExp &e) + : src_(MakePlan(e.src_)), + val_(MakePlan(e.val_)), + index_(MakePlan(e.index_)) { + } + MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { + index_t idx = static_cast(index_.Eval(0, y)); + if(idx == x) { + return static_cast(val_.Eval(0, y)); + } + else { + return static_cast(src_.Eval(y, x)); + } + } + + private: + expr::Plan src_; + expr::Plan val_; + expr::Plan index_; +}; + +template +inline Plan, DType> +MakePlan(const MatFillRowElementExp &exp) { + return Plan, DType>(exp); +} + +template +struct ShapeCheck > { + inline static Shape + Check(const MatFillRowElementExp &t) { + CHECK(dim == 2) + << "MatFillRowElementExp only support 2 dimension output"; + Shape<2> shape_src = ShapeCheck<2, SrcExp>::Check(t.src_); + Shape<1> shape_val = ShapeCheck<1, ValExp>::Check(t.val_); + Shape<1> shape_index = ShapeCheck<1, IndexExp>::Check(t.index_); + CHECK((shape_src[0] == shape_index[0]) && (shape_index[0] == shape_val[0])) + << "mat_fill_row_element index vector length, val vector length and number of rows in matrix"; + return shape_src; + } +}; + +template +struct ExpInfo > { + static const int kDim = 2; + static const int kDevMask = ExpInfo::kDevMask & ExpInfo::kDevMask & ExpInfo::kDevMask; +}; +} // namespace expr +} // namespace mshadow +#endif // MSHADOW_EXTENSION_FILL_H_ From fa897bab4acf2ed5338f9bfb244b71a910409570 Mon Sep 17 00:00:00 2001 From: sxjscience Date: Wed, 6 Jan 2016 14:27:16 +0800 Subject: [PATCH 2/2] Fix lint --- mshadow/extension/fill.h | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/mshadow/extension/fill.h b/mshadow/extension/fill.h index a7b60d258210..4ac62c1673e5 100644 --- a/mshadow/extension/fill.h +++ b/mshadow/extension/fill.h @@ -39,9 +39,10 @@ inline MatFillRowElementExp mat_fill_row_element(const Exp &src, const Exp &val, const Exp &index) { - TypeCheckPass::kDim == 2 && ExpInfo::kDim == 1 && ExpInfo::kDim == 1> - ::Error_Expression_Does_Not_Meet_Dimension_Req(); - return MatFillRowElementExp(src.self(), val.self(), index.self()); + TypeCheckPass::kDim == 2 && ExpInfo::kDim == 1 + && ExpInfo::kDim == 1>::Error_Expression_Does_Not_Meet_Dimension_Req(); + return MatFillRowElementExp(src.self(), + val.self(), index.self()); } //---------------------- @@ -57,10 +58,9 @@ struct Plan, DType> { } MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { index_t idx = static_cast(index_.Eval(0, y)); - if(idx == x) { + if (idx == x) { return static_cast(val_.Eval(0, y)); - } - else { + } else { return static_cast(src_.Eval(y, x)); } } @@ -87,7 +87,7 @@ struct ShapeCheck > { Shape<1> shape_val = ShapeCheck<1, ValExp>::Check(t.val_); Shape<1> shape_index = ShapeCheck<1, IndexExp>::Check(t.index_); CHECK((shape_src[0] == shape_index[0]) && (shape_index[0] == shape_val[0])) - << "mat_fill_row_element index vector length, val vector length and number of rows in matrix"; + << "mat_fill_row_element index length, val length and number of rows in matrix"; return shape_src; } }; @@ -95,7 +95,8 @@ struct ShapeCheck > { template struct ExpInfo > { static const int kDim = 2; - static const int kDevMask = ExpInfo::kDevMask & ExpInfo::kDevMask & ExpInfo::kDevMask; + static const int kDevMask = + ExpInfo::kDevMask & ExpInfo::kDevMask & ExpInfo::kDevMask; }; } // namespace expr } // namespace mshadow