diff --git a/guide/basic.cpp b/guide/basic.cpp index 24437bd290bd..73fd15661cb3 100644 --- a/guide/basic.cpp +++ b/guide/basic.cpp @@ -41,6 +41,7 @@ int main(void) { lhs = 1.0; rhs = 1.0; ret = implicit_dot(lhs, rhs.T()); + int cnt = 0; for (index_t i = 0; i < ret.size(0); ++i) { for (index_t j = 0; j < ret.size(1); ++j) { printf("%.2f ", ret[i][j]); @@ -48,6 +49,24 @@ int main(void) { printf("\n"); } + printf("\n"); + + for (index_t i = 0; i < lhs.size(0); ++i) { + for (index_t j = 0; j < lhs.size(1); ++j) { + lhs[i][j] = cnt++; + printf("%.2f ", lhs[i][j]); + } + printf("\n"); + } + printf("\n"); + TensorContainer index(Shape1(2)), choosed(Shape1(2)); + index[0] = 1; index[1] = 2; + choosed = mat_choose_row_element(lhs, index); + for (index_t i = 0; i < choosed.size(0); ++i) { + printf("%.2f ", choosed[i]); + } + printf("\n "); + // shutdown tensor enigne after usage ShutdownTensorEngine(); return 0; diff --git a/mshadow/extension.h b/mshadow/extension.h index d16526c8db65..702fd1913a52 100644 --- a/mshadow/extension.h +++ b/mshadow/extension.h @@ -23,4 +23,5 @@ #include "./extension/mirror.h" #include "./extension/concat.h" #include "./extension/implicit_gemm.h" +#include "./extension/choose.h" #endif // MSHADOW_EXTENSION_H_ diff --git a/mshadow/extension/choose.h b/mshadow/extension/choose.h new file mode 100644 index 000000000000..bb4482faf8db --- /dev/null +++ b/mshadow/extension/choose.h @@ -0,0 +1,93 @@ +/*! + * Copyright (c) 2014 by Contributors + * \file choose.h + * \brief support for implicit array selection operation + * \author Tianqi Chen + */ +#ifndef MSHADOW_EXTENSION_CHOOSE_H_ +#define MSHADOW_EXTENSION_CHOOSE_H_ + +#include "../extension.h" +#include "../packet-inl.h" + +namespace mshadow { +namespace expr { +/*! + * \brief Make a choice of index in the lowest changing dimension. + * \tparam SrcExp type of lhs expression + * \tparam IndexExp type of index expression + * \tparam DType the type of elements + */ +template +struct MatChooseRowElementExp: + public Exp, + DType, type::kChainer> { + /*! \brief source operand */ + const SrcExp &src_; + /*! \brief index operand */ + const IndexExp &index_; + /*! \brief constructor */ + MatChooseRowElementExp(const SrcExp &src, const IndexExp &index) + : src_(src), index_(index) {} +}; + +template +inline MatChooseRowElementExp +mat_choose_row_element(const Exp &src, + const Exp &index) { + TypeCheckPass::kDim == 2 && ExpInfo::kDim == 1> + ::Error_Expression_Does_Not_Meet_Dimension_Req(); + return MatChooseRowElementExp(src.self(), index.self()); +} + +//---------------------- +// Execution plan +//---------------------- +template +struct Plan, DType> { + public: + explicit Plan(const MatChooseRowElementExp &e) + : src_(MakePlan(e.src_)), + index_(MakePlan(e.index_)) { + } + MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { + index_t idx = static_cast(index_.Eval(0, x)); + return src_.Eval(x, idx); + } + + private: + expr::Plan src_; + expr::Plan index_; +}; + +template +inline Plan, DType> +MakePlan(const MatChooseRowElementExp &exp) { + return Plan, DType>(exp); +} + +template +struct ShapeCheck > { + inline static Shape + Check(const MatChooseRowElementExp &t) { + CHECK(dim == 1) + << "MatChooseRowElementExp only support 1 dimension output"; + Shape<2> shape1 = ShapeCheck<2, SrcExp>::Check(t.src_); + Shape shape2 = ShapeCheck::Check(t.index_); + CHECK_EQ(shape1[0], shape2[0]) + << "mat_choose_row_element index length and number of rows in matrix"; + return shape2; + } +}; + +template +struct ExpInfo > { + static const int kDim = 1; + static const int kDevMask = ExpInfo::kDevMask & ExpInfo::kDevMask; +}; +} // namespace expr +} // namespace mshadow +#endif // MSHADOW_EXTENSION_CHOOSE_H_ + +