From 9c37c2ba1f89ffe51f74b87214165aa154d594d9 Mon Sep 17 00:00:00 2001 From: Bing Xu Date: Thu, 19 Nov 2015 22:44:00 -0700 Subject: [PATCH] take grad --- guide/basic.cpp | 8 +++ mshadow/extension.h | 1 + mshadow/extension/take_grad.h | 111 ++++++++++++++++++++++++++++++++++ 3 files changed, 120 insertions(+) create mode 100644 mshadow/extension/take_grad.h diff --git a/guide/basic.cpp b/guide/basic.cpp index fe9d571a6504..b3ecd6efa1ac 100644 --- a/guide/basic.cpp +++ b/guide/basic.cpp @@ -98,6 +98,14 @@ int main(void) { } printf("\n"); } + printf("\n\n"); + weight = take_grad(idx, embed, 10); + for (index_t i = 0; i < weight.size(0); ++i) { + for (index_t j = 0; j < weight.size(1); ++j) { + printf("%.2f ", weight[i][j]); + } + printf("\n"); + } // shutdown tensor enigne after usage ShutdownTensorEngine(); diff --git a/mshadow/extension.h b/mshadow/extension.h index b9697d945cdf..52e151f7a3cd 100644 --- a/mshadow/extension.h +++ b/mshadow/extension.h @@ -27,4 +27,5 @@ #include "./extension/one_hot.h" #include "./extension/slice.h" #include "./extension/take.h" +#include "./extension/take_grad.h" #endif // MSHADOW_EXTENSION_H_ diff --git a/mshadow/extension/take_grad.h b/mshadow/extension/take_grad.h new file mode 100644 index 000000000000..e7b7bb535dad --- /dev/null +++ b/mshadow/extension/take_grad.h @@ -0,0 +1,111 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file take_grad.h + * \brief + * \author Bing Xu +*/ +#ifndef MSHADOW_EXTENSION_TAKE_GRAD_H_ +#define MSHADOW_EXTENSION_TAKE_GRAD_H_ + +#include "../extension.h" + +namespace mshadow { +namespace expr { + +/*! \brief Calculate embedding gradient + * \tparam IndexExp type of index expression + * \tparam SrcExp type of src expression + * \tparam DType data type + */ + +template +struct TakeGradExp : public Exp, + DType, type::kChainer> { + /*! \brief index oprand */ + const IndexExp &index_; + /*! \brief out gradient oprand */ + const SrcExp &src_; + /*! \brief batch size */ + const index_t input_dim_; + /*! \brief constructor */ + TakeGradExp(const IndexExp &index, const SrcExp &src, const index_t input_dim) + : index_(index), src_(src), input_dim_(input_dim) {} +}; // struct TakeGradExp + + +template +inline TakeGradExp +take_grad(const Exp &index, + const Exp &src, + const index_t input_dim) { + return TakeGradExp(index.self(), + src.self(), + input_dim); +} + +//---------------------- +// Execution plan +//---------------------- + +template +struct Plan, DType> { + public: + explicit Plan(const TakeGradExp &e) + : index_(MakePlan(e.index_)), + src_(MakePlan(e.src_)), + batch_size_(ShapeCheck<1, IndexExp>::Check(e.index_)[0]) { + } + + // now return shape: in * out + MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { + DType ret = 0.f; + for (index_t i = 0; i < batch_size_; ++i) { + index_t idx = static_cast(index_.Eval(0, i)); + if (idx == y) { + ret += static_cast(src_.Eval(i, x)); + } + } + return ret; + } + + private: + expr::Plan index_; + expr::Plan src_; + const index_t batch_size_; +}; // struct Plan + + +template +inline Plan, DType> +MakePlan(const TakeGradExp &exp) { + return Plan, DType>(exp); +} + +template +struct ShapeCheck > { + inline static Shape + Check(const TakeGradExp &t) { + CHECK(dim == 2) + << "TakeGradExp only support 2D output"; + // Shape<1> dshape = ShapeCheck<1, IndexExp>::Check(t.index_); + Shape<2> gshape = ShapeCheck<2, SrcExp>::Check(t.src_); + Shape ret; + ret[0] = t.input_dim_; + ret[1] = gshape[1]; + return ret; + } +}; // struct ShapeCheck + +template +struct ExpInfo > { + static const int kDim = 2; + static const int kDevMask = ExpInfo::kDevMask; +}; + +} // namespace expr +} // namespace mshadow + +#endif // MSHADOW_EXTENSION_TAKE_GRAD_H_