Skip to content

Commit

Permalink
Merge pull request apache#73 from antinucleon/master
Browse files Browse the repository at this point in the history
Take and Take Grad
  • Loading branch information
piiswrong committed Nov 20, 2015
2 parents f8ed18e + 9c37c2b commit bc5cb99
Show file tree
Hide file tree
Showing 6 changed files with 240 additions and 0 deletions.
28 changes: 28 additions & 0 deletions guide/basic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,34 @@ int main(void) {
printf("\n");
}
printf("\n");
TensorContainer<cpu, 1> idx(Shape1(3));
idx[0] = 8;
idx[1] = 0;
idx[1] = 1;

TensorContainer<cpu, 2> weight(Shape2(10, 5));
TensorContainer<cpu, 2> embed(Shape2(3, 5));

for (index_t i = 0; i < weight.size(0); ++i) {
for (index_t j = 0; j < weight.size(1); ++j) {
weight[i][j] = i;
}
}
embed = take(idx, weight);
for (index_t i = 0; i < embed.size(0); ++i) {
for (index_t j = 0; j < embed.size(1); ++j) {
printf("%.2f ", embed[i][j]);
}
printf("\n");
}
printf("\n\n");
weight = take_grad(idx, embed, 10);
for (index_t i = 0; i < weight.size(0); ++i) {
for (index_t j = 0; j < weight.size(1); ++j) {
printf("%.2f ", weight[i][j]);
}
printf("\n");
}

// shutdown tensor enigne after usage
ShutdownTensorEngine<cpu>();
Expand Down
2 changes: 2 additions & 0 deletions mshadow/extension.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,6 @@
#include "./extension/choose.h"
#include "./extension/one_hot.h"
#include "./extension/slice.h"
#include "./extension/take.h"
#include "./extension/take_grad.h"
#endif // MSHADOW_EXTENSION_H_
Empty file modified mshadow/extension/pack_col2patch.h
100755 → 100644
Empty file.
99 changes: 99 additions & 0 deletions mshadow/extension/take.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
/*!
* Copyright (c) 2015 by Contributors
* \file take.h
* \brief
* \author Bing Xu
*/
#ifndef MSHADOW_EXTENSION_TAKE_H_
#define MSHADOW_EXTENSION_TAKE_H_

#include "../extension.h"

namespace mshadow {
namespace expr {

/*! \brief Take a column from a matrix
* \tparam IndexExp type of index expression
* \tparam SrcExp type of src expression
* \tparam DType data type
*/
template<typename IndexExp, typename SrcExp, typename DType>
struct TakeExp: public Exp<TakeExp<IndexExp, SrcExp, DType>,
DType, type::kChainer> {
/*! \brief index oprand */
const IndexExp &index_;
/*! \brief embediing oprand */
const SrcExp &src_;
/*! constructor */
TakeExp(const IndexExp &index, const SrcExp &src)
: index_(index), src_(src) {}
}; // struct TakeExp



template<typename IndexExp,
typename SrcExp,
typename DType,
int e1, int e2>
inline TakeExp<IndexExp, SrcExp, default_real_t>
take(const Exp<IndexExp, DType, e1> &index,
const Exp<SrcExp, DType, e2> &src) {
return TakeExp<IndexExp, SrcExp, default_real_t>(index.self(), src.self());
}


//----------------------
// Execution plan
//----------------------

template<typename IndexExp, typename SrcExp, typename DType>
struct Plan<TakeExp<IndexExp, SrcExp, DType>, DType> {
public:
explicit Plan(const TakeExp<IndexExp, SrcExp, DType> &e)
: index_(MakePlan(e.index_)), src_(MakePlan(e.src_)) {
}

// TODO(xx): discuss W shape: in * out or out * in
// Now I use in * out
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const {
index_t idx = static_cast<index_t>(index_.Eval(0, y));
return static_cast<DType>(src_.Eval(idx, x));
}

private:
expr::Plan<IndexExp, DType> index_;
expr::Plan<SrcExp, DType> src_;
}; // struct Plan

template<typename IndexExp, typename SrcExp, typename DType>
inline Plan<TakeExp<IndexExp, SrcExp, DType>, DType>
MakePlan(const TakeExp<IndexExp, SrcExp, DType> &exp) {
return Plan<TakeExp<IndexExp, SrcExp, DType>, DType>(exp);
}

template<int dim, typename IndexExp, typename SrcExp, typename DType>
struct ShapeCheck<dim, TakeExp<IndexExp, SrcExp, DType> > {
inline static Shape<dim>
Check(const TakeExp<IndexExp, SrcExp, DType> &t) {
CHECK(dim == 2)
<< "TakeExp only support 2D output";
Shape<1> dshape = ShapeCheck<1, IndexExp>::Check(t.index_);
Shape<2> wshape = ShapeCheck<2, SrcExp>::Check(t.src_);
Shape<dim> ret;
ret[0] = dshape[0];
ret[1] = wshape[1];
return ret;
}
};


template<typename IndexExp, typename SrcExp, typename DType>
struct ExpInfo<TakeExp<IndexExp, SrcExp, DType> > {
static const int kDim = 2;
static const int kDevMask = ExpInfo<IndexExp>::kDevMask;
};

} // namespace expr
} // namespace mshadow

#endif // MSHADOW_EXTENSION_TAKE_H_
111 changes: 111 additions & 0 deletions mshadow/extension/take_grad.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
/*!
* Copyright (c) 2015 by Contributors
* \file take_grad.h
* \brief
* \author Bing Xu
*/
#ifndef MSHADOW_EXTENSION_TAKE_GRAD_H_
#define MSHADOW_EXTENSION_TAKE_GRAD_H_

#include "../extension.h"

namespace mshadow {
namespace expr {

/*! \brief Calculate embedding gradient
* \tparam IndexExp type of index expression
* \tparam SrcExp type of src expression
* \tparam DType data type
*/

template<typename IndexExp, typename SrcExp, typename DType>
struct TakeGradExp : public Exp<TakeGradExp<IndexExp, SrcExp, DType>,
DType, type::kChainer> {
/*! \brief index oprand */
const IndexExp &index_;
/*! \brief out gradient oprand */
const SrcExp &src_;
/*! \brief batch size */
const index_t input_dim_;
/*! \brief constructor */
TakeGradExp(const IndexExp &index, const SrcExp &src, const index_t input_dim)
: index_(index), src_(src), input_dim_(input_dim) {}
}; // struct TakeGradExp


template<typename IndexExp,
typename SrcExp,
typename DType,
int e1, int e2>
inline TakeGradExp<IndexExp, SrcExp, default_real_t>
take_grad(const Exp<IndexExp, DType, e1> &index,
const Exp<SrcExp, DType, e2> &src,
const index_t input_dim) {
return TakeGradExp<IndexExp, SrcExp, default_real_t>(index.self(),
src.self(),
input_dim);
}

//----------------------
// Execution plan
//----------------------

template<typename IndexExp, typename SrcExp, typename DType>
struct Plan<TakeGradExp<IndexExp, SrcExp, DType>, DType> {
public:
explicit Plan(const TakeGradExp<IndexExp, SrcExp, DType> &e)
: index_(MakePlan(e.index_)),
src_(MakePlan(e.src_)),
batch_size_(ShapeCheck<1, IndexExp>::Check(e.index_)[0]) {
}

// now return shape: in * out
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const {
DType ret = 0.f;
for (index_t i = 0; i < batch_size_; ++i) {
index_t idx = static_cast<index_t>(index_.Eval(0, i));
if (idx == y) {
ret += static_cast<DType>(src_.Eval(i, x));
}
}
return ret;
}

private:
expr::Plan<IndexExp, DType> index_;
expr::Plan<SrcExp, DType> src_;
const index_t batch_size_;
}; // struct Plan


template<typename IndexExp, typename SrcExp, typename DType>
inline Plan<TakeGradExp<IndexExp, SrcExp, DType>, DType>
MakePlan(const TakeGradExp<IndexExp, SrcExp, DType> &exp) {
return Plan<TakeGradExp<IndexExp, SrcExp, DType>, DType>(exp);
}

template<int dim, typename IndexExp, typename SrcExp, typename DType>
struct ShapeCheck<dim, TakeGradExp<IndexExp, SrcExp, DType> > {
inline static Shape<dim>
Check(const TakeGradExp<IndexExp, SrcExp, DType> &t) {
CHECK(dim == 2)
<< "TakeGradExp only support 2D output";
// Shape<1> dshape = ShapeCheck<1, IndexExp>::Check(t.index_);
Shape<2> gshape = ShapeCheck<2, SrcExp>::Check(t.src_);
Shape<dim> ret;
ret[0] = t.input_dim_;
ret[1] = gshape[1];
return ret;
}
}; // struct ShapeCheck

template<typename IndexExp, typename SrcExp, typename DType>
struct ExpInfo<TakeGradExp<IndexExp, SrcExp, DType> > {
static const int kDim = 2;
static const int kDevMask = ExpInfo<IndexExp>::kDevMask;
};

} // namespace expr
} // namespace mshadow

#endif // MSHADOW_EXTENSION_TAKE_GRAD_H_
Empty file modified mshadow/extension/unpack_patch2col.h
100755 → 100644
Empty file.

0 comments on commit bc5cb99

Please sign in to comment.