Skip to content

Commit

Permalink
Merge pull request apache#58 from tqchen/master
Browse files Browse the repository at this point in the history
Add implicit gemm
  • Loading branch information
tqchen committed Oct 17, 2015
2 parents 56e6153 + aa77e19 commit 9aaa23b
Show file tree
Hide file tree
Showing 5 changed files with 152 additions and 1 deletion.
12 changes: 12 additions & 0 deletions guide/basic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,18 @@ int main(void) {
}
printf("\n");
}

TensorContainer<cpu, 2> lhs(Shape2(2, 3)), rhs(Shape2(2, 3)), ret(Shape2(2,2));
lhs = 1.0;
rhs = 1.0;
ret = implicit_dot(lhs, rhs.T());
for (index_t i = 0; i < ret.size(0); ++i) {
for (index_t j = 0; j < ret.size(1); ++j) {
printf("%.2f ", ret[i][j]);
}
printf("\n");
}

// shutdown tensor enigne after usage
ShutdownTensorEngine<cpu>();
return 0;
Expand Down
3 changes: 2 additions & 1 deletion mshadow/extension.h
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*!
* Copyright by Contributors
* \file extension.h
* \brief some extension of expressions,
* \brief some extension of expressions,
* used to support something beyond elementwise op
* \author Tianqi Chen, Bing Xu
*/
Expand All @@ -22,4 +22,5 @@
#include "./extension/crop.h"
#include "./extension/mirror.h"
#include "./extension/concat.h"
#include "./extension/implicit_gemm.h"
#endif // MSHADOW_EXTENSION_H_
127 changes: 127 additions & 0 deletions mshadow/extension/implicit_gemm.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
/*!
* Copyright (c) 2014 by Contributors
* \file implicit_gemm.h
* \brief support for implicit GEMM operation
* \author Tianqi Chen
*/
#ifndef MSHADOW_EXTENSION_IMPLICIT_GEMM_H_
#define MSHADOW_EXTENSION_IMPLICIT_GEMM_H_

#include "../extension.h"
#include "../packet-inl.h"

namespace mshadow {
namespace expr {
/*!
* \brief Matrix multiplication.
* \tparam LhsExp type of lhs expression
* \tparam LhsExp type of rhs expression
* \tparam DType the type of elements
*/
template<typename LhsExp, typename RhsExp, typename DType>
struct ImplicitGEMMExp:
public Exp<ImplicitGEMMExp<LhsExp, RhsExp, DType>,
DType, type::kChainer> {
/*! \brief lhs operand */
const LhsExp &lhs_;
/*! \brief rhs operand */
const RhsExp &rhs_;
/*! \brief internal production size*/
index_t prod_size_;
/*! \brief the shape of this expression */
Shape<2> shape_;
/*! \brief constructor */
ImplicitGEMMExp(const LhsExp &lhs, const RhsExp &rhs)
: lhs_(lhs), rhs_(rhs) {
Shape<2> slhs = ShapeCheck<2, LhsExp>::Check(lhs_);
Shape<2> srhs = ShapeCheck<2, RhsExp>::Check(rhs_);
this->shape_ = mshadow::Shape2(slhs[0], srhs[1]);
prod_size_ = slhs[1];
}
};


template<typename LhsExp, typename RhsExp, typename DType, int e1, int e2>
inline ImplicitGEMMExp<LhsExp, RhsExp, DType>
implicit_dot(const Exp<LhsExp, DType, e1> &lhs,
const Exp<RhsExp, DType, e2> &rhs) {
TypeCheckPass<ExpInfo<LhsExp>::kDim == 2 && ExpInfo<RhsExp>::kDim == 2>
::Error_Expression_Does_Not_Meet_Dimension_Req();
return ImplicitGEMMExp<LhsExp, RhsExp, DType>(lhs.self(), rhs.self());
}

//----------------------
// Execution plan
//----------------------
template<typename LhsExp, typename RhsExp, typename DType>
struct Plan<ImplicitGEMMExp<LhsExp, RhsExp, DType>, DType> {
public:
explicit Plan(const ImplicitGEMMExp<LhsExp, RhsExp, DType> &e)
: lhs_(MakePlan(e.lhs_)),
rhs_(MakePlan(e.rhs_)),
prod_size_(e.prod_size_),
prod_size_lower_align_(packet::LowerAlign<DType, MSHADOW_DEFAULT_PACKET>(e.prod_size_)) {
}

MSHADOW_XINLINE DType Eval(index_t y, index_t x) const {
typedef packet::Packet<DType> Packet;
Packet sum = Packet::Fill(0);

DType lhs_temp[Packet::kSize], rhs_temp[Packet::kSize];

for (index_t i = 0; i < prod_size_lower_align_; i += packet::Packet<DType>::kSize) {
// unroll
for (index_t j = 0; j < Packet::kSize; ++j) {
lhs_temp[j] = lhs_.Eval(y, i + j);
}
for (index_t j = 0; j < Packet::kSize; ++j) {
rhs_temp[j] = rhs_.Eval(i + j, x);
}
sum = sum + Packet::LoadUnAligned(lhs_temp) * Packet::LoadUnAligned(rhs_temp);
}
DType ret_result = sum.Sum();

for (index_t i = prod_size_lower_align_; i < prod_size_; ++i) {
ret_result += lhs_.Eval(y, i) * rhs_.Eval(i, x);
}
return ret_result;
}

private:
expr::Plan<LhsExp, DType> lhs_;
expr::Plan<RhsExp, DType> rhs_;
const index_t prod_size_;
const index_t prod_size_lower_align_;
};

template<typename LhsExp, typename RhsExp, typename DType>
inline Plan<ImplicitGEMMExp<LhsExp, RhsExp, DType>, DType>
MakePlan(const ImplicitGEMMExp<LhsExp, RhsExp, DType> &exp) {
return Plan<ImplicitGEMMExp<LhsExp, RhsExp, DType>, DType>(exp);
}


template<int dim, typename LhsExp, typename RhsExp, typename DType>
struct ShapeCheck<dim, ImplicitGEMMExp<LhsExp, RhsExp, DType> > {
inline static Shape<dim>
Check(const ImplicitGEMMExp<LhsExp, RhsExp, DType> &t) {
CHECK(dim == 2)
<< "ImplicitGEMMExp only support 2 dimension";
Shape<dim> shape1 = ShapeCheck<dim, LhsExp>::Check(t.lhs_);
Shape<dim> shape2 = ShapeCheck<dim, RhsExp>::Check(t.rhs_);
CHECK_EQ(shape1[1], shape2[0])
<< "implicit_dot The matrix shape do not match";
return t.shape_;
}
};

template<typename LhsExp, typename RhsExp, typename DType>
struct ExpInfo<ImplicitGEMMExp<LhsExp, RhsExp, DType> > {
static const int kDim = 2;
static const int kDevMask = ExpInfo<LhsExp>::kDevMask & ExpInfo<RhsExp>::kDevMask;
};

} // namespace expr
} // namespace mshadow
#endif // MSHADOW_EXTENSION_IMPLICIT_GEMM_H_

4 changes: 4 additions & 0 deletions mshadow/packet/plain-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ struct Packet<DType, kPlain> {
MSHADOW_CINLINE static Packet<DType, kPlain> Load(const DType* src) {
return Packet<DType, kPlain>(*src);
}
// load from address
MSHADOW_CINLINE static Packet<DType, kPlain> LoadUnAligned(const DType* src) {
return Packet<DType, kPlain>(*src);
}
// fill it with value s
MSHADOW_CINLINE Packet<DType, kPlain>& operator=(DType s) {
data_ = s;
Expand Down
7 changes: 7 additions & 0 deletions mshadow/packet/sse-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ struct Packet<float, kSSE2> {
MSHADOW_CINLINE static Packet<float, kSSE2> Load(const float* src) {
return Packet<float, kSSE2>(_mm_load_ps(src));
}
// load from address
MSHADOW_CINLINE static Packet<float, kSSE2> LoadUnAligned(const float* src) {
return Packet<float, kSSE2>(_mm_loadu_ps(src));
}
// fill it with value s
MSHADOW_CINLINE Packet<float, kSSE2>& operator=(float s) {
data_ = _mm_set1_ps(s);
Expand Down Expand Up @@ -73,6 +77,9 @@ struct Packet<double, kSSE2> {
MSHADOW_CINLINE static Packet<double, kSSE2> Load(const double* src) {
return Packet<double, kSSE2>(_mm_load_pd(src));
}
MSHADOW_CINLINE static Packet<double, kSSE2> LoadUnAligned(const double* src) {
return Packet<double, kSSE2>(_mm_loadu_pd(src));
}
// fill it with value s
MSHADOW_CINLINE Packet<double, kSSE2>& operator=(double s) {
data_ = _mm_set1_pd(s);
Expand Down

0 comments on commit 9aaa23b

Please sign in to comment.