forked from apache/mxnet
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request apache#85 from piiswrong/master
broadcast and reduce with axis
- Loading branch information
Showing
3 changed files
with
185 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
/*!
 * Copyright (c) 2015 by Contributors
 * \file broadcast_with_axis.h
 * \brief expression that broadcasts a tensor along a newly inserted axis
 * \author Junyuan Xie
 */
#ifndef MSHADOW_EXTENSION_BROADCAST_WITH_AXIS_H_ | ||
#define MSHADOW_EXTENSION_BROADCAST_WITH_AXIS_H_ | ||
|
||
#include "../extension.h" | ||
|
||
namespace mshadow { | ||
namespace expr { | ||
|
||
/*! \brief Backward for tensor dot | ||
* \tparam DataExp type of left expression | ||
* \tparam TopExp type of right expression | ||
* \tparam DType data type | ||
*/ | ||
template<int axis, typename SrcExp, typename DType, int srcdim> | ||
struct BroadcastWithAxisExp: | ||
public MakeTensorExp<BroadcastWithAxisExp<axis, SrcExp, DType, srcdim>, | ||
SrcExp, srcdim+1, DType> { | ||
/*! \brief data oprand */ | ||
const SrcExp &src_; | ||
/*! \brief size of middle dimension */ | ||
index_t leading_; | ||
/*! \brief size of middle dimension */ | ||
index_t trailing_; | ||
/*! \brief size of middle dimension */ | ||
index_t size_; | ||
/*! \brief size of middle dimension */ | ||
index_t last_; | ||
/*! constructor */ | ||
BroadcastWithAxisExp(const SrcExp &src, const index_t size) | ||
: src_(src), size_(size) { | ||
CHECK(srcdim > axis) << "broadcast axis out of bound"; | ||
Shape<srcdim> src_shape = ShapeCheck<srcdim, SrcExp>::Check(src_); | ||
this->leading_ = 1; | ||
for (index_t i = 0; i <= axis; ++i) { | ||
this->leading_ *= src_shape[i]; | ||
this->shape_[i] = src_shape[i]; | ||
} | ||
this->shape_[axis+1] = size_; | ||
this->trailing_ = 1; | ||
for (index_t i = axis+1; i < srcdim; ++i) { | ||
this->trailing_ *= src_shape[i]; | ||
this->shape_[i+1] = src_shape[i]; | ||
} | ||
this->last_ = src_shape[srcdim-1]; | ||
} | ||
}; // struct BroadcastWithAxisExp | ||
|
||
/*! | ||
* \brief pooling subregion results together | ||
* \param data data oprand | ||
* \param top top grad oprand | ||
* \tparam DataExp left expression | ||
* \tparam TopExp right expression | ||
* \tparam DType the content data type | ||
*/ | ||
template<int axis, typename SrcExp, typename DType, int etype> | ||
inline BroadcastWithAxisExp<axis, SrcExp, DType, ExpInfo<SrcExp>::kDim> | ||
broadcast_with_axis(const Exp<SrcExp, DType, etype> &src, const index_t size) { | ||
return BroadcastWithAxisExp<axis, SrcExp, DType, ExpInfo<SrcExp>::kDim>(src.self(), size); | ||
} | ||
//---------------------- | ||
// Execution plan | ||
//---------------------- | ||
template<int axis, typename SrcExp, typename DType, int srcdim> | ||
struct Plan<BroadcastWithAxisExp<axis, SrcExp, DType, srcdim>, DType> { | ||
public: | ||
explicit Plan(const BroadcastWithAxisExp<axis, SrcExp, DType, srcdim> &e) | ||
: src_(MakePlan(e.src_)), leading_(e.leading_), | ||
trailing_(e.trailing_), size_(e.size_), last_(e.last_) {} | ||
MSHADOW_XINLINE DType Eval(index_t i, index_t j) const { | ||
index_t x = (i*last_+j)/trailing_/size_; | ||
index_t y = (i*last_+j)%trailing_; | ||
index_t z = x*trailing_ + y; | ||
return src_.Eval(z/last_, z%last_); | ||
} | ||
|
||
private: | ||
Plan<SrcExp, DType> src_; | ||
const index_t leading_, trailing_, size_, last_; | ||
}; | ||
} // namespace expr | ||
} // namespace mshadow | ||
#endif // MSHADOW_EXTENSION_BROADCAST_WITH_AXIS_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
/*! | ||
* Copyright (c) 2015 by Contributors | ||
* \file reduce_with_axis.h | ||
* \brief | ||
* \author Junyuan Xie | ||
*/ | ||
#ifndef MSHADOW_EXTENSION_REDUCE_WITH_AXIS_H_ | ||
#define MSHADOW_EXTENSION_REDUCE_WITH_AXIS_H_ | ||
|
||
#include "../extension.h" | ||
|
||
namespace mshadow { | ||
namespace expr { | ||
|
||
/*! \brief reduce out the dimension of src labeled by axis. | ||
* \tparam Reducer type of reducer | ||
* \tparam SrcExp type of source expression | ||
* \tparam DType data type | ||
*/ | ||
template<typename Reducer, int axis, typename SrcExp, typename DType, int srcdim> | ||
struct ReduceWithAxisExp: | ||
public MakeTensorExp<ReduceWithAxisExp<Reducer, axis, SrcExp, DType, srcdim>, | ||
SrcExp, srcdim-1, DType> { | ||
/*! \brief source oprand */ | ||
const SrcExp &src_; | ||
/*! \brief size of leading dimensions */ | ||
index_t leading_; | ||
/*! \brief size of trailing dimensions */ | ||
index_t trailing_; | ||
/*! \brief size of axis dimension */ | ||
index_t size_; | ||
/*! \brief size of last src dimension */ | ||
index_t last_; | ||
/*! constructor */ | ||
explicit ReduceWithAxisExp(const SrcExp &src) | ||
: src_(src) { | ||
CHECK(srcdim > axis) << "reduce axis out of bound"; | ||
Shape<srcdim> src_shape = ShapeCheck<srcdim, SrcExp>::Check(src_); | ||
this->leading_ = 1; | ||
for (index_t i = 0; i < axis; ++i) { | ||
this->leading_ *= src_shape[i]; | ||
this->shape_[i] = src_shape[i]; | ||
} | ||
this->size_ = src_shape[axis]; | ||
this->trailing_ = 1; | ||
for (index_t i = axis + 1; i < srcdim; ++i) { | ||
this->trailing_ *= src_shape[i]; | ||
this->shape_[i-1] = src_shape[i]; | ||
} | ||
this->last_ = src_shape[srcdim-1]; | ||
} | ||
}; // struct ReduceWithAxisExp | ||
|
||
/*! | ||
* \brief pooling subregion results together | ||
* \param lhs left oprand | ||
* \param rhs right oprand | ||
* \tparam LhsExp left expression | ||
* \tparam RhsExp right expression | ||
* \tparam DType the content data type | ||
*/ | ||
template<typename Reducer, int axis, typename SrcExp, typename DType, int etype> | ||
inline ReduceWithAxisExp<Reducer, axis, SrcExp, DType, ExpInfo<SrcExp>::kDim> | ||
reduce_with_axis(const Exp<SrcExp, DType, etype> &src) { | ||
return ReduceWithAxisExp<Reducer, axis, SrcExp, DType, ExpInfo<SrcExp>::kDim>(src.self()); | ||
} | ||
//---------------------- | ||
// Execution plan | ||
//---------------------- | ||
template<typename Reducer, int axis, typename SrcExp, typename DType, int srcdim> | ||
struct Plan<ReduceWithAxisExp<Reducer, axis, SrcExp, DType, srcdim>, DType> { | ||
public: | ||
explicit Plan(const ReduceWithAxisExp<Reducer, axis, SrcExp, DType, srcdim> &e) | ||
: src_(MakePlan(e.src_)), leading_(e.leading_), trailing_(e.trailing_), | ||
size_(e.size_), last_(e.last_) {} | ||
MSHADOW_XINLINE DType Eval(index_t i, index_t j) const { | ||
index_t x = (i*last_ + j)/trailing_; | ||
index_t y = (i*last_ + j)%trailing_; | ||
|
||
DType res; Reducer::SetInitValue(res); | ||
for (index_t k = 0; k < size_; ++k) { | ||
index_t z = (x*size_+k)*trailing_+y; | ||
Reducer::Reduce(res, src_.Eval(z/last_, z%last_)); | ||
} | ||
return res; | ||
} | ||
|
||
private: | ||
Plan<SrcExp, DType> src_; | ||
const index_t leading_, trailing_, size_, last_; | ||
}; | ||
} // namespace expr | ||
} // namespace mshadow | ||
#endif // MSHADOW_EXTENSION_REDUCE_WITH_AXIS_H_ |