Skip to content

Commit

Permalink
Add choose operation to choose row elements from a matrix
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen committed Oct 21, 2015
1 parent 2f9e617 commit 5d34606
Show file tree
Hide file tree
Showing 3 changed files with 113 additions and 0 deletions.
19 changes: 19 additions & 0 deletions guide/basic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,32 @@ int main(void) {
lhs = 1.0;
rhs = 1.0;
ret = implicit_dot(lhs, rhs.T());
int cnt = 0;
for (index_t i = 0; i < ret.size(0); ++i) {
for (index_t j = 0; j < ret.size(1); ++j) {
printf("%.2f ", ret[i][j]);
}
printf("\n");
}

printf("\n");

for (index_t i = 0; i < lhs.size(0); ++i) {
for (index_t j = 0; j < lhs.size(1); ++j) {
lhs[i][j] = cnt++;
printf("%.2f ", lhs[i][j]);
}
printf("\n");
}
printf("\n");
TensorContainer<cpu, 1> index(Shape1(2)), choosed(Shape1(2));
index[0] = 1; index[1] = 2;
choosed = mat_choose_row_element(lhs, index);
for (index_t i = 0; i < choosed.size(0); ++i) {
printf("%.2f ", choosed[i]);
}
printf("\n ");

// shutdown tensor enigne after usage
ShutdownTensorEngine<cpu>();
return 0;
Expand Down
1 change: 1 addition & 0 deletions mshadow/extension.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,5 @@
#include "./extension/mirror.h"
#include "./extension/concat.h"
#include "./extension/implicit_gemm.h"
#include "./extension/choose.h"
#endif // MSHADOW_EXTENSION_H_
93 changes: 93 additions & 0 deletions mshadow/extension/choose.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
/*!
* Copyright (c) 2014 by Contributors
* \file choose.h
* \brief support for implicit array selection operation
* \author Tianqi Chen
*/
#ifndef MSHADOW_EXTENSION_CHOOSE_H_
#define MSHADOW_EXTENSION_CHOOSE_H_

#include "../extension.h"
#include "../packet-inl.h"

namespace mshadow {
namespace expr {
/*!
* \brief Make a choice of index in the lowest changing dimension.
* \tparam SrcExp type of lhs expression
* \tparam IndexExp type of index expression
* \tparam DType the type of elements
*/
template<typename SrcExp, typename IndexExp, typename DType>
struct MatChooseRowElementExp:
public Exp<MatChooseRowElementExp<SrcExp, IndexExp, DType>,
DType, type::kChainer> {
/*! \brief source operand */
const SrcExp &src_;
/*! \brief index operand */
const IndexExp &index_;
/*! \brief constructor */
MatChooseRowElementExp(const SrcExp &src, const IndexExp &index)
: src_(src), index_(index) {}
};

template<typename SrcExp, typename IndexExp,
typename DType, typename IDType, int e1, int e2>
inline MatChooseRowElementExp<SrcExp, IndexExp, DType>
mat_choose_row_element(const Exp<SrcExp, DType, e1> &src,
const Exp<IndexExp, IDType, e2> &index) {
TypeCheckPass<ExpInfo<SrcExp>::kDim == 2 && ExpInfo<IndexExp>::kDim == 1>
::Error_Expression_Does_Not_Meet_Dimension_Req();
return MatChooseRowElementExp<SrcExp, IndexExp, DType>(src.self(), index.self());
}

//----------------------
// Execution plan
//----------------------
template<typename SrcExp, typename IndexExp, typename DType>
struct Plan<MatChooseRowElementExp<SrcExp, IndexExp, DType>, DType> {
public:
explicit Plan(const MatChooseRowElementExp<SrcExp, IndexExp, DType> &e)
: src_(MakePlan(e.src_)),
index_(MakePlan(e.index_)) {
}
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const {
index_t idx = static_cast<index_t>(index_.Eval(0, x));
return src_.Eval(x, idx);
}

private:
expr::Plan<SrcExp, DType> src_;
expr::Plan<IndexExp, DType> index_;
};

template<typename SrcExp, typename IndexExp, typename DType>
inline Plan<MatChooseRowElementExp<SrcExp, IndexExp, DType>, DType>
MakePlan(const MatChooseRowElementExp<SrcExp, IndexExp, DType> &exp) {
return Plan<MatChooseRowElementExp<SrcExp, IndexExp, DType>, DType>(exp);
}

template<int dim, typename SrcExp, typename IndexExp, typename DType>
struct ShapeCheck<dim, MatChooseRowElementExp<SrcExp, IndexExp, DType> > {
inline static Shape<dim>
Check(const MatChooseRowElementExp<SrcExp, IndexExp, DType> &t) {
CHECK(dim == 1)
<< "MatChooseRowElementExp only support 1 dimension output";
Shape<2> shape1 = ShapeCheck<2, SrcExp>::Check(t.src_);
Shape<dim> shape2 = ShapeCheck<dim, IndexExp>::Check(t.index_);
CHECK_EQ(shape1[0], shape2[0])
<< "mat_choose_row_element index length and number of rows in matrix";
return shape2;
}
};

template<typename SrcExp, typename IndexExp, typename DType>
struct ExpInfo<MatChooseRowElementExp<SrcExp, IndexExp, DType> > {
static const int kDim = 1;
static const int kDevMask = ExpInfo<SrcExp>::kDevMask & ExpInfo<IndexExp>::kDevMask;
};
} // namespace expr
} // namespace mshadow
#endif // MSHADOW_EXTENSION_CHOOSE_H_


0 comments on commit 5d34606

Please sign in to comment.