forked from apache/mxnet
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request apache#73 from antinucleon/master
Take and Take Grad
- Loading branch information
Showing
6 changed files
with
240 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
/*! | ||
* Copyright (c) 2015 by Contributors | ||
* \file take.h | ||
 * \brief lazy expression that gathers rows of a 2D source selected by an index expression | ||
* \author Bing Xu | ||
*/ | ||
#ifndef MSHADOW_EXTENSION_TAKE_H_ | ||
#define MSHADOW_EXTENSION_TAKE_H_ | ||
|
||
#include "../extension.h" | ||
|
||
namespace mshadow { | ||
namespace expr { | ||
|
||
/*! \brief Take a column from a matrix | ||
* \tparam IndexExp type of index expression | ||
* \tparam SrcExp type of src expression | ||
* \tparam DType data type | ||
*/ | ||
template<typename IndexExp, typename SrcExp, typename DType> | ||
struct TakeExp: public Exp<TakeExp<IndexExp, SrcExp, DType>, | ||
DType, type::kChainer> { | ||
/*! \brief index oprand */ | ||
const IndexExp &index_; | ||
/*! \brief embediing oprand */ | ||
const SrcExp &src_; | ||
/*! constructor */ | ||
TakeExp(const IndexExp &index, const SrcExp &src) | ||
: index_(index), src_(src) {} | ||
}; // struct TakeExp | ||
|
||
|
||
|
||
template<typename IndexExp, | ||
typename SrcExp, | ||
typename DType, | ||
int e1, int e2> | ||
inline TakeExp<IndexExp, SrcExp, default_real_t> | ||
take(const Exp<IndexExp, DType, e1> &index, | ||
const Exp<SrcExp, DType, e2> &src) { | ||
return TakeExp<IndexExp, SrcExp, default_real_t>(index.self(), src.self()); | ||
} | ||
|
||
|
||
//---------------------- | ||
// Execution plan | ||
//---------------------- | ||
|
||
template<typename IndexExp, typename SrcExp, typename DType> | ||
struct Plan<TakeExp<IndexExp, SrcExp, DType>, DType> { | ||
public: | ||
explicit Plan(const TakeExp<IndexExp, SrcExp, DType> &e) | ||
: index_(MakePlan(e.index_)), src_(MakePlan(e.src_)) { | ||
} | ||
|
||
// TODO(xx): discuss W shape: in * out or out * in | ||
// Now I use in * out | ||
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { | ||
index_t idx = static_cast<index_t>(index_.Eval(0, y)); | ||
return static_cast<DType>(src_.Eval(idx, x)); | ||
} | ||
|
||
private: | ||
expr::Plan<IndexExp, DType> index_; | ||
expr::Plan<SrcExp, DType> src_; | ||
}; // struct Plan | ||
|
||
template<typename IndexExp, typename SrcExp, typename DType> | ||
inline Plan<TakeExp<IndexExp, SrcExp, DType>, DType> | ||
MakePlan(const TakeExp<IndexExp, SrcExp, DType> &exp) { | ||
return Plan<TakeExp<IndexExp, SrcExp, DType>, DType>(exp); | ||
} | ||
|
||
template<int dim, typename IndexExp, typename SrcExp, typename DType> | ||
struct ShapeCheck<dim, TakeExp<IndexExp, SrcExp, DType> > { | ||
inline static Shape<dim> | ||
Check(const TakeExp<IndexExp, SrcExp, DType> &t) { | ||
CHECK(dim == 2) | ||
<< "TakeExp only support 2D output"; | ||
Shape<1> dshape = ShapeCheck<1, IndexExp>::Check(t.index_); | ||
Shape<2> wshape = ShapeCheck<2, SrcExp>::Check(t.src_); | ||
Shape<dim> ret; | ||
ret[0] = dshape[0]; | ||
ret[1] = wshape[1]; | ||
return ret; | ||
} | ||
}; | ||
|
||
|
||
template<typename IndexExp, typename SrcExp, typename DType> | ||
struct ExpInfo<TakeExp<IndexExp, SrcExp, DType> > { | ||
static const int kDim = 2; | ||
static const int kDevMask = ExpInfo<IndexExp>::kDevMask; | ||
}; | ||
|
||
} // namespace expr | ||
} // namespace mshadow | ||
|
||
#endif // MSHADOW_EXTENSION_TAKE_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
/*! | ||
* Copyright (c) 2015 by Contributors | ||
* \file take_grad.h | ||
 * \brief lazy expression that accumulates gradient rows into the rows named by an index expression | ||
* \author Bing Xu | ||
*/ | ||
#ifndef MSHADOW_EXTENSION_TAKE_GRAD_H_ | ||
#define MSHADOW_EXTENSION_TAKE_GRAD_H_ | ||
|
||
#include "../extension.h" | ||
|
||
namespace mshadow { | ||
namespace expr { | ||
|
||
/*! \brief Calculate embedding gradient | ||
* \tparam IndexExp type of index expression | ||
* \tparam SrcExp type of src expression | ||
* \tparam DType data type | ||
*/ | ||
|
||
template<typename IndexExp, typename SrcExp, typename DType> | ||
struct TakeGradExp : public Exp<TakeGradExp<IndexExp, SrcExp, DType>, | ||
DType, type::kChainer> { | ||
/*! \brief index oprand */ | ||
const IndexExp &index_; | ||
/*! \brief out gradient oprand */ | ||
const SrcExp &src_; | ||
/*! \brief batch size */ | ||
const index_t input_dim_; | ||
/*! \brief constructor */ | ||
TakeGradExp(const IndexExp &index, const SrcExp &src, const index_t input_dim) | ||
: index_(index), src_(src), input_dim_(input_dim) {} | ||
}; // struct TakeGradExp | ||
|
||
|
||
template<typename IndexExp, | ||
typename SrcExp, | ||
typename DType, | ||
int e1, int e2> | ||
inline TakeGradExp<IndexExp, SrcExp, default_real_t> | ||
take_grad(const Exp<IndexExp, DType, e1> &index, | ||
const Exp<SrcExp, DType, e2> &src, | ||
const index_t input_dim) { | ||
return TakeGradExp<IndexExp, SrcExp, default_real_t>(index.self(), | ||
src.self(), | ||
input_dim); | ||
} | ||
|
||
//---------------------- | ||
// Execution plan | ||
//---------------------- | ||
|
||
template<typename IndexExp, typename SrcExp, typename DType> | ||
struct Plan<TakeGradExp<IndexExp, SrcExp, DType>, DType> { | ||
public: | ||
explicit Plan(const TakeGradExp<IndexExp, SrcExp, DType> &e) | ||
: index_(MakePlan(e.index_)), | ||
src_(MakePlan(e.src_)), | ||
batch_size_(ShapeCheck<1, IndexExp>::Check(e.index_)[0]) { | ||
} | ||
|
||
// now return shape: in * out | ||
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { | ||
DType ret = 0.f; | ||
for (index_t i = 0; i < batch_size_; ++i) { | ||
index_t idx = static_cast<index_t>(index_.Eval(0, i)); | ||
if (idx == y) { | ||
ret += static_cast<DType>(src_.Eval(i, x)); | ||
} | ||
} | ||
return ret; | ||
} | ||
|
||
private: | ||
expr::Plan<IndexExp, DType> index_; | ||
expr::Plan<SrcExp, DType> src_; | ||
const index_t batch_size_; | ||
}; // struct Plan | ||
|
||
|
||
template<typename IndexExp, typename SrcExp, typename DType> | ||
inline Plan<TakeGradExp<IndexExp, SrcExp, DType>, DType> | ||
MakePlan(const TakeGradExp<IndexExp, SrcExp, DType> &exp) { | ||
return Plan<TakeGradExp<IndexExp, SrcExp, DType>, DType>(exp); | ||
} | ||
|
||
template<int dim, typename IndexExp, typename SrcExp, typename DType> | ||
struct ShapeCheck<dim, TakeGradExp<IndexExp, SrcExp, DType> > { | ||
inline static Shape<dim> | ||
Check(const TakeGradExp<IndexExp, SrcExp, DType> &t) { | ||
CHECK(dim == 2) | ||
<< "TakeGradExp only support 2D output"; | ||
// Shape<1> dshape = ShapeCheck<1, IndexExp>::Check(t.index_); | ||
Shape<2> gshape = ShapeCheck<2, SrcExp>::Check(t.src_); | ||
Shape<dim> ret; | ||
ret[0] = t.input_dim_; | ||
ret[1] = gshape[1]; | ||
return ret; | ||
} | ||
}; // struct ShapeCheck | ||
|
||
template<typename IndexExp, typename SrcExp, typename DType> | ||
struct ExpInfo<TakeGradExp<IndexExp, SrcExp, DType> > { | ||
static const int kDim = 2; | ||
static const int kDevMask = ExpInfo<IndexExp>::kDevMask; | ||
}; | ||
|
||
} // namespace expr | ||
} // namespace mshadow | ||
|
||
#endif // MSHADOW_EXTENSION_TAKE_GRAD_H_ |
Empty file.