Skip to content

Commit

Permalink
Merge pull request apache#85 from piiswrong/master
Browse files Browse the repository at this point in the history
broadcast and reduce with axis
  • Loading branch information
piiswrong committed Dec 13, 2015
2 parents ea75d09 + 629d09f commit 258d0b5
Show file tree
Hide file tree
Showing 3 changed files with 185 additions and 0 deletions.
2 changes: 2 additions & 0 deletions mshadow/extension.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,7 @@
#include "./extension/slice.h"
#include "./extension/take.h"
#include "./extension/take_grad.h"
#include "./extension/reduce_with_axis.h"
#include "./extension/broadcast_with_axis.h"
#include "./extension/spatial_upsampling_nearest.h"
#endif // MSHADOW_EXTENSION_H_
89 changes: 89 additions & 0 deletions mshadow/extension/broadcast_with_axis.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
/*!
* Copyright (c) 2015 by Contributors
* \file tensor_dot.h
* \brief
* \author Junyuan Xie
*/
#ifndef MSHADOW_EXTENSION_BROADCAST_WITH_AXIS_H_
#define MSHADOW_EXTENSION_BROADCAST_WITH_AXIS_H_

#include "../extension.h"

namespace mshadow {
namespace expr {

/*! \brief Backward for tensor dot
* \tparam DataExp type of left expression
* \tparam TopExp type of right expression
* \tparam DType data type
*/
template<int axis, typename SrcExp, typename DType, int srcdim>
struct BroadcastWithAxisExp:
public MakeTensorExp<BroadcastWithAxisExp<axis, SrcExp, DType, srcdim>,
SrcExp, srcdim+1, DType> {
/*! \brief data oprand */
const SrcExp &src_;
/*! \brief size of middle dimension */
index_t leading_;
/*! \brief size of middle dimension */
index_t trailing_;
/*! \brief size of middle dimension */
index_t size_;
/*! \brief size of middle dimension */
index_t last_;
/*! constructor */
BroadcastWithAxisExp(const SrcExp &src, const index_t size)
: src_(src), size_(size) {
CHECK(srcdim > axis) << "broadcast axis out of bound";
Shape<srcdim> src_shape = ShapeCheck<srcdim, SrcExp>::Check(src_);
this->leading_ = 1;
for (index_t i = 0; i <= axis; ++i) {
this->leading_ *= src_shape[i];
this->shape_[i] = src_shape[i];
}
this->shape_[axis+1] = size_;
this->trailing_ = 1;
for (index_t i = axis+1; i < srcdim; ++i) {
this->trailing_ *= src_shape[i];
this->shape_[i+1] = src_shape[i];
}
this->last_ = src_shape[srcdim-1];
}
}; // struct BroadcastWithAxisExp

/*!
* \brief pooling subregion results together
* \param data data oprand
* \param top top grad oprand
* \tparam DataExp left expression
* \tparam TopExp right expression
* \tparam DType the content data type
*/
template<int axis, typename SrcExp, typename DType, int etype>
inline BroadcastWithAxisExp<axis, SrcExp, DType, ExpInfo<SrcExp>::kDim>
broadcast_with_axis(const Exp<SrcExp, DType, etype> &src, const index_t size) {
return BroadcastWithAxisExp<axis, SrcExp, DType, ExpInfo<SrcExp>::kDim>(src.self(), size);
}
//----------------------
// Execution plan
//----------------------
template<int axis, typename SrcExp, typename DType, int srcdim>
struct Plan<BroadcastWithAxisExp<axis, SrcExp, DType, srcdim>, DType> {
public:
explicit Plan(const BroadcastWithAxisExp<axis, SrcExp, DType, srcdim> &e)
: src_(MakePlan(e.src_)), leading_(e.leading_),
trailing_(e.trailing_), size_(e.size_), last_(e.last_) {}
MSHADOW_XINLINE DType Eval(index_t i, index_t j) const {
index_t x = (i*last_+j)/trailing_/size_;
index_t y = (i*last_+j)%trailing_;
index_t z = x*trailing_ + y;
return src_.Eval(z/last_, z%last_);
}

private:
Plan<SrcExp, DType> src_;
const index_t leading_, trailing_, size_, last_;
};
} // namespace expr
} // namespace mshadow
#endif // MSHADOW_EXTENSION_BROADCAST_WITH_AXIS_H_
94 changes: 94 additions & 0 deletions mshadow/extension/reduce_with_axis.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
/*!
* Copyright (c) 2015 by Contributors
* \file reduce_with_axis.h
* \brief
* \author Junyuan Xie
*/
#ifndef MSHADOW_EXTENSION_REDUCE_WITH_AXIS_H_
#define MSHADOW_EXTENSION_REDUCE_WITH_AXIS_H_

#include "../extension.h"

namespace mshadow {
namespace expr {

/*! \brief reduce out the dimension of src labeled by axis.
* \tparam Reducer type of reducer
* \tparam SrcExp type of source expression
* \tparam DType data type
*/
template<typename Reducer, int axis, typename SrcExp, typename DType, int srcdim>
struct ReduceWithAxisExp:
public MakeTensorExp<ReduceWithAxisExp<Reducer, axis, SrcExp, DType, srcdim>,
SrcExp, srcdim-1, DType> {
/*! \brief source oprand */
const SrcExp &src_;
/*! \brief size of leading dimensions */
index_t leading_;
/*! \brief size of trailing dimensions */
index_t trailing_;
/*! \brief size of axis dimension */
index_t size_;
/*! \brief size of last src dimension */
index_t last_;
/*! constructor */
explicit ReduceWithAxisExp(const SrcExp &src)
: src_(src) {
CHECK(srcdim > axis) << "reduce axis out of bound";
Shape<srcdim> src_shape = ShapeCheck<srcdim, SrcExp>::Check(src_);
this->leading_ = 1;
for (index_t i = 0; i < axis; ++i) {
this->leading_ *= src_shape[i];
this->shape_[i] = src_shape[i];
}
this->size_ = src_shape[axis];
this->trailing_ = 1;
for (index_t i = axis + 1; i < srcdim; ++i) {
this->trailing_ *= src_shape[i];
this->shape_[i-1] = src_shape[i];
}
this->last_ = src_shape[srcdim-1];
}
}; // struct ReduceWithAxisExp

/*!
* \brief pooling subregion results together
* \param lhs left oprand
* \param rhs right oprand
* \tparam LhsExp left expression
* \tparam RhsExp right expression
* \tparam DType the content data type
*/
template<typename Reducer, int axis, typename SrcExp, typename DType, int etype>
inline ReduceWithAxisExp<Reducer, axis, SrcExp, DType, ExpInfo<SrcExp>::kDim>
reduce_with_axis(const Exp<SrcExp, DType, etype> &src) {
return ReduceWithAxisExp<Reducer, axis, SrcExp, DType, ExpInfo<SrcExp>::kDim>(src.self());
}
//----------------------
// Execution plan
//----------------------
template<typename Reducer, int axis, typename SrcExp, typename DType, int srcdim>
struct Plan<ReduceWithAxisExp<Reducer, axis, SrcExp, DType, srcdim>, DType> {
public:
explicit Plan(const ReduceWithAxisExp<Reducer, axis, SrcExp, DType, srcdim> &e)
: src_(MakePlan(e.src_)), leading_(e.leading_), trailing_(e.trailing_),
size_(e.size_), last_(e.last_) {}
MSHADOW_XINLINE DType Eval(index_t i, index_t j) const {
index_t x = (i*last_ + j)/trailing_;
index_t y = (i*last_ + j)%trailing_;

DType res; Reducer::SetInitValue(res);
for (index_t k = 0; k < size_; ++k) {
index_t z = (x*size_+k)*trailing_+y;
Reducer::Reduce(res, src_.Eval(z/last_, z%last_));
}
return res;
}

private:
Plan<SrcExp, DType> src_;
const index_t leading_, trailing_, size_, last_;
};
} // namespace expr
} // namespace mshadow
#endif // MSHADOW_EXTENSION_REDUCE_WITH_AXIS_H_

0 comments on commit 258d0b5

Please sign in to comment.