diff --git a/guide/basic.cpp b/guide/basic.cpp
index 7a15bf4092c3..b3ecd6efa1ac 100644
--- a/guide/basic.cpp
+++ b/guide/basic.cpp
@@ -78,6 +78,38 @@ int main(void) {
     printf("\n");
   }
   printf("\n");
+  // indices of the rows of weight to gather; every entry must be assigned
+  TensorContainer<cpu, 1> idx(Shape1(3));
+  idx[0] = 8;
+  idx[1] = 0;
+  idx[2] = 1;
+
+  TensorContainer<cpu, 2> weight(Shape2(10, 5));
+  TensorContainer<cpu, 2> embed(Shape2(3, 5));
+
+  // fill row i of weight with the value i
+  for (index_t i = 0; i < weight.size(0); ++i) {
+    for (index_t j = 0; j < weight.size(1); ++j) {
+      weight[i][j] = i;
+    }
+  }
+  // gather: row i of embed is row idx[i] of weight
+  embed = take(idx, weight);
+  for (index_t i = 0; i < embed.size(0); ++i) {
+    for (index_t j = 0; j < embed.size(1); ++j) {
+      printf("%.2f ", embed[i][j]);
+    }
+    printf("\n");
+  }
+  printf("\n\n");
+  // scatter-add the rows of embed into a 10-row matrix
+  weight = take_grad(idx, embed, 10);
+  for (index_t i = 0; i < weight.size(0); ++i) {
+    for (index_t j = 0; j < weight.size(1); ++j) {
+      printf("%.2f ", weight[i][j]);
+    }
+    printf("\n");
+  }
 
   // shutdown tensor enigne after usage
   ShutdownTensorEngine<cpu>();
diff --git a/mshadow/extension.h b/mshadow/extension.h
index 066d31ff51c1..52e151f7a3cd 100644
--- a/mshadow/extension.h
+++ b/mshadow/extension.h
@@ -26,4 +26,6 @@
 #include "./extension/choose.h"
 #include "./extension/one_hot.h"
 #include "./extension/slice.h"
+#include "./extension/take.h"
+#include "./extension/take_grad.h"
 #endif  // MSHADOW_EXTENSION_H_
diff --git a/mshadow/extension/pack_col2patch.h b/mshadow/extension/pack_col2patch.h
old mode 100755
new mode 100644
diff --git a/mshadow/extension/take.h b/mshadow/extension/take.h
new file mode 100644
index 000000000000..f0b77b5700a5
--- /dev/null
+++ b/mshadow/extension/take.h
@@ -0,0 +1,106 @@
+/*!
+ * Copyright (c) 2015 by Contributors
+ * \file take.h
+ * \brief expression that gathers rows of a matrix selected by an index vector
+ * \author Bing Xu
+*/
+#ifndef MSHADOW_EXTENSION_TAKE_H_
+#define MSHADOW_EXTENSION_TAKE_H_
+
+#include "../extension.h"
+
+namespace mshadow {
+namespace expr {
+
+/*!
+ * \brief Take rows from a matrix: row i of the result is row index[i] of src
+ * \tparam IndexExp type of index expression
+ * \tparam SrcExp type of src expression
+ * \tparam DType data type
+ */
+template<typename IndexExp, typename SrcExp, typename DType>
+struct TakeExp: public Exp<TakeExp<IndexExp, SrcExp, DType>,
+                           DType, type::kChainer> {
+  /*! \brief index operand */
+  const IndexExp &index_;
+  /*! \brief embedding operand */
+  const SrcExp &src_;
+  /*! \brief constructor; only references to both operands are stored */
+  TakeExp(const IndexExp &index, const SrcExp &src)
+    : index_(index), src_(src) {}
+};  // struct TakeExp
+
+/*!
+ * \brief build the expression that gathers rows of src selected by index
+ * \param index 1D index expression, entry i selects a row of src
+ * \param src 2D source expression
+ * \return a TakeExp over the two operands
+ */
+template<typename IndexExp,
+         typename SrcExp,
+         typename DType,
+         int e1, int e2>
+inline TakeExp<IndexExp, SrcExp, DType>
+take(const Exp<IndexExp, DType, e1> &index,
+     const Exp<SrcExp, DType, e2> &src) {
+  return TakeExp<IndexExp, SrcExp, DType>(index.self(), src.self());
+}
+
+//----------------------
+// Execution plan
+//----------------------
+
+template<typename IndexExp, typename SrcExp, typename DType>
+struct Plan<TakeExp<IndexExp, SrcExp, DType>, DType> {
+ public:
+  explicit Plan(const TakeExp<IndexExp, SrcExp, DType> &e)
+    : index_(MakePlan(e.index_)), src_(MakePlan(e.src_)) {
+  }
+
+  // TODO(xx): discuss W shape: in * out or out * in
+  // Now I use in * out
+  /*! \brief element (y, x) of the result is src(index[y], x) */
+  MSHADOW_XINLINE DType Eval(index_t y, index_t x) const {
+    // the index value is not range-checked against the rows of src
+    index_t idx = static_cast<index_t>(index_.Eval(0, y));
+    return static_cast<DType>(src_.Eval(idx, x));
+  }
+
+ private:
+  expr::Plan<IndexExp, DType> index_;
+  expr::Plan<SrcExp, DType> src_;
+};  // struct Plan
+
+template<typename IndexExp, typename SrcExp, typename DType>
+inline Plan<TakeExp<IndexExp, SrcExp, DType>, DType>
+MakePlan(const TakeExp<IndexExp, SrcExp, DType> &exp) {
+  return Plan<TakeExp<IndexExp, SrcExp, DType>, DType>(exp);
+}
+
+/*! \brief output shape: (length of index) x (columns of src) */
+template<int dim, typename IndexExp, typename SrcExp, typename DType>
+struct ShapeCheck<dim, TakeExp<IndexExp, SrcExp, DType> > {
+  inline static Shape<dim>
+  Check(const TakeExp<IndexExp, SrcExp, DType> &t) {
+    CHECK(dim == 2)
+      << "TakeExp only support 2D output";
+    Shape<1> dshape = ShapeCheck<1, IndexExp>::Check(t.index_);
+    Shape<2> wshape = ShapeCheck<2, SrcExp>::Check(t.src_);
+    Shape<dim> ret;
+    ret[0] = dshape[0];
+    ret[1] = wshape[1];
+    return ret;
+  }
+};
+
+/*! \brief TakeExp always yields a 2D expression */
+template<typename IndexExp, typename SrcExp, typename DType>
+struct ExpInfo<TakeExp<IndexExp, SrcExp, DType> > {
+  static const int kDim = 2;
+  static const int kDevMask = ExpInfo<IndexExp>::kDevMask;
+};
+
+}  // namespace expr
+}  // namespace mshadow
+
+#endif  // MSHADOW_EXTENSION_TAKE_H_
diff --git a/mshadow/extension/take_grad.h b/mshadow/extension/take_grad.h
new file mode 100644
index 000000000000..e7b7bb535dad
--- /dev/null
+++ b/mshadow/extension/take_grad.h
@@ -0,0 +1,122 @@
+/*!
+ * Copyright (c) 2015 by Contributors
+ * \file take_grad.h
+ * \brief expression that sums out-gradient rows grouped by index
+ * \author Bing Xu
+*/
+#ifndef MSHADOW_EXTENSION_TAKE_GRAD_H_
+#define MSHADOW_EXTENSION_TAKE_GRAD_H_
+
+#include "../extension.h"
+
+namespace mshadow {
+namespace expr {
+
+/*!
+ * \brief Calculate embedding gradient: row y of the result is the sum of
+ *        the rows i of src for which index[i] == y
+ * \tparam IndexExp type of index expression
+ * \tparam SrcExp type of src expression
+ * \tparam DType data type
+ */
+template<typename IndexExp, typename SrcExp, typename DType>
+struct TakeGradExp : public Exp<TakeGradExp<IndexExp, SrcExp, DType>,
+                                DType, type::kChainer> {
+  /*! \brief index operand */
+  const IndexExp &index_;
+  /*! \brief out gradient operand */
+  const SrcExp &src_;
+  /*! \brief number of rows of the result (first dimension of the output) */
+  const index_t input_dim_;
+  /*! \brief constructor */
+  TakeGradExp(const IndexExp &index, const SrcExp &src, const index_t input_dim)
+    : index_(index), src_(src), input_dim_(input_dim) {}
+};  // struct TakeGradExp
+
+/*!
+ * \brief build the expression that scatters and sums rows of src by index
+ * \param index 1D index expression, entry i names the result row of src row i
+ * \param src 2D out gradient expression
+ * \param input_dim number of rows of the result
+ * \return a TakeGradExp over the operands
+ */
+template<typename IndexExp,
+         typename SrcExp,
+         typename DType,
+         int e1, int e2>
+inline TakeGradExp<IndexExp, SrcExp, DType>
+take_grad(const Exp<IndexExp, DType, e1> &index,
+          const Exp<SrcExp, DType, e2> &src,
+          const index_t input_dim) {
+  return TakeGradExp<IndexExp, SrcExp, DType>(index.self(),
+                                              src.self(),
+                                              input_dim);
+}
+
+//----------------------
+// Execution plan
+//----------------------
+
+template<typename IndexExp, typename SrcExp, typename DType>
+struct Plan<TakeGradExp<IndexExp, SrcExp, DType>, DType> {
+ public:
+  explicit Plan(const TakeGradExp<IndexExp, SrcExp, DType> &e)
+    : index_(MakePlan(e.index_)),
+      src_(MakePlan(e.src_)),
+      batch_size_(ShapeCheck<1, IndexExp>::Check(e.index_)[0]) {
+  }
+
+  // now return shape: in * out
+  /*! \brief sum of src(i, x) over every i with index[i] == y */
+  MSHADOW_XINLINE DType Eval(index_t y, index_t x) const {
+    DType ret = 0.f;
+    // every output element scans the whole index vector
+    for (index_t i = 0; i < batch_size_; ++i) {
+      index_t idx = static_cast<index_t>(index_.Eval(0, i));
+      if (idx == y) {
+        ret += static_cast<DType>(src_.Eval(i, x));
+      }
+    }
+    return ret;
+  }
+
+ private:
+  expr::Plan<IndexExp, DType> index_;
+  expr::Plan<SrcExp, DType> src_;
+  /*! \brief length of the index vector */
+  const index_t batch_size_;
+};  // struct Plan
+
+template<typename IndexExp, typename SrcExp, typename DType>
+inline Plan<TakeGradExp<IndexExp, SrcExp, DType>, DType>
+MakePlan(const TakeGradExp<IndexExp, SrcExp, DType> &exp) {
+  return Plan<TakeGradExp<IndexExp, SrcExp, DType>, DType>(exp);
+}
+
+/*! \brief output shape: input_dim x (columns of src) */
+template<int dim, typename IndexExp, typename SrcExp, typename DType>
+struct ShapeCheck<dim, TakeGradExp<IndexExp, SrcExp, DType> > {
+  inline static Shape<dim>
+  Check(const TakeGradExp<IndexExp, SrcExp, DType> &t) {
+    CHECK(dim == 2)
+      << "TakeGradExp only support 2D output";
+    // NOTE: the index length is not validated against the rows of src here
+    Shape<2> gshape = ShapeCheck<2, SrcExp>::Check(t.src_);
+    Shape<dim> ret;
+    ret[0] = t.input_dim_;
+    ret[1] = gshape[1];
+    return ret;
+  }
+};  // struct ShapeCheck
+
+/*! \brief TakeGradExp always yields a 2D expression */
+template<typename IndexExp, typename SrcExp, typename DType>
+struct ExpInfo<TakeGradExp<IndexExp, SrcExp, DType> > {
+  static const int kDim = 2;
+  static const int kDevMask = ExpInfo<IndexExp>::kDevMask;
+};
+
+}  // namespace expr
+}  // namespace mshadow
+
+#endif  // MSHADOW_EXTENSION_TAKE_GRAD_H_
diff --git a/mshadow/extension/unpack_patch2col.h b/mshadow/extension/unpack_patch2col.h
old mode 100755
new mode 100644