diff --git a/mshadow/extension.h b/mshadow/extension.h index 7f821068834d..01e11f1ba9c0 100644 --- a/mshadow/extension.h +++ b/mshadow/extension.h @@ -28,5 +28,7 @@ #include "./extension/slice.h" #include "./extension/take.h" #include "./extension/take_grad.h" +#include "./extension/reduce_with_axis.h" +#include "./extension/broadcast_with_axis.h" #include "./extension/spatial_upsampling_nearest.h" #endif // MSHADOW_EXTENSION_H_ diff --git a/mshadow/extension/broadcast_with_axis.h b/mshadow/extension/broadcast_with_axis.h new file mode 100644 index 000000000000..c10b907d2efe --- /dev/null +++ b/mshadow/extension/broadcast_with_axis.h @@ -0,0 +1,89 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file tensor_dot.h + * \brief + * \author Junyuan Xie +*/ +#ifndef MSHADOW_EXTENSION_BROADCAST_WITH_AXIS_H_ +#define MSHADOW_EXTENSION_BROADCAST_WITH_AXIS_H_ + +#include "../extension.h" + +namespace mshadow { +namespace expr { + +/*! \brief Backward for tensor dot + * \tparam DataExp type of left expression + * \tparam TopExp type of right expression + * \tparam DType data type + */ +template +struct BroadcastWithAxisExp: + public MakeTensorExp, + SrcExp, srcdim+1, DType> { + /*! \brief data oprand */ + const SrcExp &src_; + /*! \brief size of middle dimension */ + index_t leading_; + /*! \brief size of middle dimension */ + index_t trailing_; + /*! \brief size of middle dimension */ + index_t size_; + /*! \brief size of middle dimension */ + index_t last_; + /*! constructor */ + BroadcastWithAxisExp(const SrcExp &src, const index_t size) + : src_(src), size_(size) { + CHECK(srcdim > axis) << "broadcast axis out of bound"; + Shape src_shape = ShapeCheck::Check(src_); + this->leading_ = 1; + for (index_t i = 0; i <= axis; ++i) { + this->leading_ *= src_shape[i]; + this->shape_[i] = src_shape[i]; + } + this->shape_[axis+1] = size_; + this->trailing_ = 1; + for (index_t i = axis+1; i < srcdim; ++i) { + this->trailing_ *= src_shape[i]; + this->shape_[i+1] = src_shape[i]; + } + this->last_ = src_shape[srcdim-1]; + } +}; // struct BroadcastWithAxisExp + +/*! + * \brief pooling subregion results together + * \param data data oprand + * \param top top grad oprand + * \tparam DataExp left expression + * \tparam TopExp right expression + * \tparam DType the content data type + */ +template +inline BroadcastWithAxisExp::kDim> +broadcast_with_axis(const Exp &src, const index_t size) { + return BroadcastWithAxisExp::kDim>(src.self(), size); +} +//---------------------- +// Execution plan +//---------------------- +template +struct Plan, DType> { + public: + explicit Plan(const BroadcastWithAxisExp &e) + : src_(MakePlan(e.src_)), leading_(e.leading_), + trailing_(e.trailing_), size_(e.size_), last_(e.last_) {} + MSHADOW_XINLINE DType Eval(index_t i, index_t j) const { + index_t x = (i*last_+j)/trailing_/size_; + index_t y = (i*last_+j)%trailing_; + index_t z = x*trailing_ + y; + return src_.Eval(z/last_, z%last_); + } + + private: + Plan src_; + const index_t leading_, trailing_, size_, last_; +}; +} // namespace expr +} // namespace mshadow +#endif // MSHADOW_EXTENSION_BROADCAST_WITH_AXIS_H_ diff --git a/mshadow/extension/reduce_with_axis.h b/mshadow/extension/reduce_with_axis.h new file mode 100644 index 000000000000..b1c090ea7fa0 --- /dev/null +++ b/mshadow/extension/reduce_with_axis.h @@ -0,0 +1,94 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file reduce_with_axis.h + * \brief + * \author Junyuan Xie +*/ +#ifndef MSHADOW_EXTENSION_REDUCE_WITH_AXIS_H_ +#define MSHADOW_EXTENSION_REDUCE_WITH_AXIS_H_ + +#include "../extension.h" + +namespace mshadow { +namespace expr { + +/*! \brief reduce out the dimension of src labeled by axis. + * \tparam Reducer type of reducer + * \tparam SrcExp type of source expression + * \tparam DType data type + */ +template +struct ReduceWithAxisExp: + public MakeTensorExp, + SrcExp, srcdim-1, DType> { + /*! \brief source oprand */ + const SrcExp &src_; + /*! \brief size of leading dimensions */ + index_t leading_; + /*! \brief size of trailing dimensions */ + index_t trailing_; + /*! \brief size of axis dimension */ + index_t size_; + /*! \brief size of last src dimension */ + index_t last_; + /*! constructor */ + explicit ReduceWithAxisExp(const SrcExp &src) + : src_(src) { + CHECK(srcdim > axis) << "reduce axis out of bound"; + Shape src_shape = ShapeCheck::Check(src_); + this->leading_ = 1; + for (index_t i = 0; i < axis; ++i) { + this->leading_ *= src_shape[i]; + this->shape_[i] = src_shape[i]; + } + this->size_ = src_shape[axis]; + this->trailing_ = 1; + for (index_t i = axis + 1; i < srcdim; ++i) { + this->trailing_ *= src_shape[i]; + this->shape_[i-1] = src_shape[i]; + } + this->last_ = src_shape[srcdim-1]; + } +}; // struct ReduceWithAxisExp + +/*! + * \brief pooling subregion results together + * \param lhs left oprand + * \param rhs right oprand + * \tparam LhsExp left expression + * \tparam RhsExp right expression + * \tparam DType the content data type + */ +template +inline ReduceWithAxisExp::kDim> +reduce_with_axis(const Exp &src) { + return ReduceWithAxisExp::kDim>(src.self()); +} +//---------------------- +// Execution plan +//---------------------- +template +struct Plan, DType> { + public: + explicit Plan(const ReduceWithAxisExp &e) + : src_(MakePlan(e.src_)), leading_(e.leading_), trailing_(e.trailing_), + size_(e.size_), last_(e.last_) {} + MSHADOW_XINLINE DType Eval(index_t i, index_t j) const { + index_t x = (i*last_ + j)/trailing_; + index_t y = (i*last_ + j)%trailing_; + + DType res; Reducer::SetInitValue(res); + for (index_t k = 0; k < size_; ++k) { + index_t z = (x*size_+k)*trailing_+y; + Reducer::Reduce(res, src_.Eval(z/last_, z%last_)); + } + return res; + } + + private: + Plan src_; + const index_t leading_, trailing_, size_, last_; +}; +} // namespace expr +} // namespace mshadow +#endif // MSHADOW_EXTENSION_REDUCE_WITH_AXIS_H_