Skip to content

Commit

Permalink
Merge pull request apache#69 from tqchen/master
Browse files Browse the repository at this point in the history
Fix GPU dot, add slice
  • Loading branch information
tqchen committed Oct 27, 2015
2 parents f2d0e25 + 3b40a9c commit 27ba6a6
Show file tree
Hide file tree
Showing 3 changed files with 157 additions and 2 deletions.
2 changes: 0 additions & 2 deletions mshadow/dot_engine-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,6 @@ struct BLASEngine<gpu> {
CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: Dot fail";
cublasSetPointerMode(Stream<gpu>::GetBlasHandle(stream),
CUBLAS_POINTER_MODE_HOST);
return ret;
}
inline static void dot(Stream<gpu> *stream,
int n,
Expand All @@ -248,7 +247,6 @@ struct BLASEngine<gpu> {
CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: Dot fail";
cublasSetPointerMode(Stream<gpu>::GetBlasHandle(stream),
CUBLAS_POINTER_MODE_HOST);
return ret;
}
};
#endif // MSHADOW_USE_CUDA
Expand Down
1 change: 1 addition & 0 deletions mshadow/extension.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,5 @@
#include "./extension/implicit_gemm.h"
#include "./extension/choose.h"
#include "./extension/one_hot.h"
#include "./extension/slice.h"
#endif // MSHADOW_EXTENSION_H_
156 changes: 156 additions & 0 deletions mshadow/extension/slice.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
/*!
* Copyright (c) 2014 by Contributors
* \file slice.h
* \brief support for slice a certain dimension.
*/
#ifndef MSHADOW_EXTENSION_SLICE_H_
#define MSHADOW_EXTENSION_SLICE_H_

#include "../extension.h"

namespace mshadow {
namespace expr {
/*!
* \brief slice expression, slice a tensor's channel
* \tparam SrcExp left expression
* \tparam DType the type of elements
* \tparam srcdim dimension of src
* \tparam dimsrc_m_cat dimsrc - dimcat
*/
template<typename SrcExp,
typename Device, typename DType,
int srcdim, int dimsrc_m_slice>
struct SliceExp : public TRValue<SliceExp<SrcExp,
Device, DType,
srcdim, dimsrc_m_slice>,
Device, srcdim, DType> {
static const int dimslice = srcdim - dimsrc_m_slice;
const SrcExp &src_;
index_t ch_begin_;
index_t ch_old_;
Shape<srcdim> shape_;
SliceExp(const SrcExp &src, index_t begin, index_t end)
: src_(src), ch_begin_(begin) {
shape_ = ShapeCheck<srcdim, SrcExp>::Check(src_);
ch_old_ = shape_[dimslice];
CHECK(begin < shape_[dimslice] && end <= shape_[dimslice])
<< "The slice went out of range";
shape_[dimslice] = end - begin;
}
template<typename E, int etype>
inline void
operator=(const expr::Exp<E, DType, etype> &exp) {
this->__assign(exp);
}
inline void
operator=(const DType &exp) {
this->__assign(exp);
}
}; // struct Slice

/*!
* \brief Slice a Tensor
* \param src source tensor
* \param begin The beginning slice.
* \param end The end slice.
* \return sliced tensor
* \tparam sdim the dimension to slice on
* \tparam SrcExp source expression
* \tparam DType the type of elements
* \tparam etype type of expression
*/
template<int sdim, typename SrcExp,
typename Device, typename DType, int srcdim>
inline SliceExp<SrcExp, Device, DType, srcdim, srcdim - sdim>
slice(const TRValue<SrcExp, Device, srcdim, DType> &src, index_t begin, index_t end) {
TypeCheckPass<sdim < srcdim && ExpInfo<SrcExp>::kDim == srcdim>
::Error_Expression_Does_Not_Meet_Dimension_Req();
return SliceExp<SrcExp, Device, DType, srcdim, srcdim - sdim>(src.self(), begin, end);
}
//------------------------
// engine plugin
//------------------------
// runtime shapecheck
template<typename SrcExp,
typename Device, typename DType,
int srcdim, int dimsrc_m_slice>
struct ShapeCheck<srcdim, SliceExp<SrcExp, Device, DType, srcdim, dimsrc_m_slice> >{
inline static Shape<srcdim> Check(const SliceExp<SrcExp,
Device, DType, srcdim, dimsrc_m_slice> &t) {
return t.shape_;
}
};
template<typename SrcExp,
typename Device, typename DType,
int srcdim, int dimsrc_m_slice>
struct StreamInfo<Device, SliceExp<SrcExp, Device, DType, srcdim, dimsrc_m_slice> >{
inline static Stream<Device> *
Get(const SliceExp<SrcExp, Device, DType, srcdim, dimsrc_m_slice> &t) {
return StreamInfo<Device, SrcExp>::Get(t.src_);
}
};
// static typecheck
template<typename SrcExp,
typename Device, typename DType,
int srcdim, int dimsrc_m_slice>
struct ExpInfo<SliceExp<SrcExp, Device, DType, srcdim, dimsrc_m_slice> >{
static const int kDim = ExpInfo<SrcExp>::kDim;
static const int kDevMask = ExpInfo<SrcExp>::kDevMask;
};
//----------------------
// Execution plan
//---------------------
template<typename SrcExp,
typename Device, typename DType,
int srcdim, int dimsrc_m_slice>
struct Plan<SliceExp<SrcExp, Device, DType, srcdim, dimsrc_m_slice>, DType> {
public:
static const int dimslice = srcdim - dimsrc_m_slice;
explicit Plan(const SliceExp<SrcExp, Device, DType, srcdim, dimsrc_m_slice> &e)
: src_(MakePlan(e.src_)),
height_(e.shape_.ProdShape(dimslice + 1, srcdim - 1)),
ch_begin_(e.ch_begin_), ch_old_(e.ch_old_), ch_(e.shape_[dimslice]) {}
MSHADOW_XINLINE DType Eval(index_t i, index_t j) const {
const index_t y = i % height_;
i /= height_;
const index_t c = i % ch_ + ch_begin_;
const index_t b = i / ch_;
const index_t x = j;
return src_.Eval((b * ch_old_ + c) * height_ + y, x);
}
MSHADOW_XINLINE DType &REval(index_t i, index_t j) {
const index_t y = i % height_;
i /= height_;
const index_t c = i % ch_ + ch_begin_;
const index_t b = i / ch_;
const index_t x = j;
return src_.REval((b * ch_old_ + c) * height_ + y, x);
}

private:
Plan<SrcExp, DType> src_;
const index_t height_, ch_begin_, ch_old_, ch_;
}; // struct Plan

template<typename SrcExp,
typename Device, typename DType,
int srcdim>
struct Plan<SliceExp<SrcExp, Device, DType, srcdim, 1>, DType> {
public:
explicit Plan(const SliceExp<SrcExp, Device, DType, srcdim, 1> &e)
: src_(MakePlan(e.src_)),
ch_begin_(e.ch_begin_) {}
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const {
return src_.Eval(y, x + ch_begin_);
}
MSHADOW_XINLINE DType &REval(index_t y, index_t x) {
return src_.REval(y, x + ch_begin_);
}

private:
Plan<SrcExp, DType> src_;
const index_t ch_begin_;
};
} // namespace expr
} // namespace mshadow
#endif // MSHADOW_EXTENSION_SLICE_H_

0 comments on commit 27ba6a6

Please sign in to comment.