From d61c768d865f59df6ff59695403a5dbc1aebc786 Mon Sep 17 00:00:00 2001 From: tqchen Date: Mon, 26 Oct 2015 20:47:31 -0700 Subject: [PATCH] Add slice --- mshadow/extension.h | 1 + mshadow/extension/slice.h | 156 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 157 insertions(+) create mode 100644 mshadow/extension/slice.h diff --git a/mshadow/extension.h b/mshadow/extension.h index 336699f7dbe1..066d31ff51c1 100644 --- a/mshadow/extension.h +++ b/mshadow/extension.h @@ -25,4 +25,5 @@ #include "./extension/implicit_gemm.h" #include "./extension/choose.h" #include "./extension/one_hot.h" +#include "./extension/slice.h" #endif // MSHADOW_EXTENSION_H_ diff --git a/mshadow/extension/slice.h b/mshadow/extension/slice.h new file mode 100644 index 000000000000..cb2eff4548aa --- /dev/null +++ b/mshadow/extension/slice.h @@ -0,0 +1,156 @@ +/*! + * Copyright (c) 2014 by Contributors + * \file slice.h + * \brief support for slice a certain dimension. + */ +#ifndef MSHADOW_EXTENSION_SLICE_H_ +#define MSHADOW_EXTENSION_SLICE_H_ + +#include "../extension.h" + +namespace mshadow { +namespace expr { +/*! + * \brief slice expression, slice a tensor's channel + * \tparam SrcExp left expression + * \tparam DType the type of elements + * \tparam srcdim dimension of src + * \tparam dimsrc_m_cat dimsrc - dimcat + */ +template +struct SliceExp : public TRValue, + Device, srcdim, DType> { + static const int dimslice = srcdim - dimsrc_m_slice; + const SrcExp &src_; + index_t ch_begin_; + index_t ch_old_; + Shape shape_; + SliceExp(const SrcExp &src, index_t begin, index_t end) + : src_(src), ch_begin_(begin) { + shape_ = ShapeCheck::Check(src_); + ch_old_ = shape_[dimslice]; + CHECK(begin < shape_[dimslice] && end <= shape_[dimslice]) + << "The slice went out of range"; + shape_[dimslice] = end - begin; + } + template + inline void + operator=(const expr::Exp &exp) { + this->__assign(exp); + } + inline void + operator=(const DType &exp) { + this->__assign(exp); + } +}; // struct Slice + +/*! + * \brief Slice a Tensor + * \param src source tensor + * \param begin The beginning slice. + * \param end The end slice. + * \return sliced tensor + * \tparam sdim the dimension to slice on + * \tparam SrcExp source expression + * \tparam DType the type of elements + * \tparam etype type of expression + */ +template +inline SliceExp +slice(const TRValue &src, index_t begin, index_t end) { + TypeCheckPass::kDim == srcdim> + ::Error_Expression_Does_Not_Meet_Dimension_Req(); + return SliceExp(src.self(), begin, end); +} +//------------------------ +// engine plugin +//------------------------ +// runtime shapecheck +template +struct ShapeCheck >{ + inline static Shape Check(const SliceExp &t) { + return t.shape_; + } +}; +template +struct StreamInfo >{ + inline static Stream * + Get(const SliceExp &t) { + return StreamInfo::Get(t.src_); + } +}; +// static typecheck +template +struct ExpInfo >{ + static const int kDim = ExpInfo::kDim; + static const int kDevMask = ExpInfo::kDevMask; +}; +//---------------------- +// Execution plan +//--------------------- +template +struct Plan, DType> { + public: + static const int dimslice = srcdim - dimsrc_m_slice; + explicit Plan(const SliceExp &e) + : src_(MakePlan(e.src_)), + height_(e.shape_.ProdShape(dimslice + 1, srcdim - 1)), + ch_begin_(e.ch_begin_), ch_old_(e.ch_old_), ch_(e.shape_[dimslice]) {} + MSHADOW_XINLINE DType Eval(index_t i, index_t j) const { + const index_t y = i % height_; + i /= height_; + const index_t c = i % ch_ + ch_begin_; + const index_t b = i / ch_; + const index_t x = j; + return src_.Eval((b * ch_old_ + c) * height_ + y, x); + } + MSHADOW_XINLINE DType &REval(index_t i, index_t j) { + const index_t y = i % height_; + i /= height_; + const index_t c = i % ch_ + ch_begin_; + const index_t b = i / ch_; + const index_t x = j; + return src_.REval((b * ch_old_ + c) * height_ + y, x); + } + + private: + Plan src_; + const index_t height_, ch_begin_, ch_old_, ch_; +}; // struct Plan + +template +struct Plan, DType> { + public: + explicit Plan(const SliceExp &e) + : src_(MakePlan(e.src_)), + ch_begin_(e.ch_begin_) {} + MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { + return src_.Eval(y, x + ch_begin_); + } + MSHADOW_XINLINE DType &REval(index_t y, index_t x) { + return src_.REval(y, x + ch_begin_); + } + + private: + Plan src_; + const index_t ch_begin_; +}; +} // namespace expr +} // namespace mshadow +#endif // MSHADOW_EXTENSION_SLICE_H_