From de24af848bf9d8fadac1cb1fdbe5a68cee5d1791 Mon Sep 17 00:00:00 2001 From: tqchen Date: Sat, 17 Oct 2015 13:07:03 -0700 Subject: [PATCH 1/2] Add implicit gemm --- guide/basic.cpp | 12 +++ mshadow/extension.h | 3 +- mshadow/extension/implicit_gemm.h | 127 ++++++++++++++++++++++++++++++ 3 files changed, 141 insertions(+), 1 deletion(-) create mode 100644 mshadow/extension/implicit_gemm.h diff --git a/guide/basic.cpp b/guide/basic.cpp index 04c3185608b6..24437bd290bd 100644 --- a/guide/basic.cpp +++ b/guide/basic.cpp @@ -36,6 +36,18 @@ int main(void) { } printf("\n"); } + + TensorContainer lhs(Shape2(2, 3)), rhs(Shape2(2, 3)), ret(Shape2(2,2)); + lhs = 1.0; + rhs = 1.0; + ret = implicit_dot(lhs, rhs.T()); + for (index_t i = 0; i < ret.size(0); ++i) { + for (index_t j = 0; j < ret.size(1); ++j) { + printf("%.2f ", ret[i][j]); + } + printf("\n"); + } + // shutdown tensor enigne after usage ShutdownTensorEngine(); return 0; diff --git a/mshadow/extension.h b/mshadow/extension.h index fab96e4e092f..d16526c8db65 100644 --- a/mshadow/extension.h +++ b/mshadow/extension.h @@ -1,7 +1,7 @@ /*! * Copyright by Contributors * \file extension.h - * \brief some extension of expressions, + * \brief some extension of expressions, * used to support something beyond elementwise op * \author Tianqi Chen, Bing Xu */ @@ -22,4 +22,5 @@ #include "./extension/crop.h" #include "./extension/mirror.h" #include "./extension/concat.h" +#include "./extension/implicit_gemm.h" #endif // MSHADOW_EXTENSION_H_ diff --git a/mshadow/extension/implicit_gemm.h b/mshadow/extension/implicit_gemm.h new file mode 100644 index 000000000000..8b32d93b2c68 --- /dev/null +++ b/mshadow/extension/implicit_gemm.h @@ -0,0 +1,127 @@ +/*! + * Copyright (c) 2014 by Contributors + * \file implicit_gemm.h + * \brief support for implicit GEMM operation + * \author Tianqi Chen + */ +#ifndef MSHADOW_EXTENSION_IMPLICIT_GEMM_H_ +#define MSHADOW_EXTENSION_IMPLICIT_GEMM_H_ + +#include "../extension.h" +#include "../packet-inl.h" + +namespace mshadow { +namespace expr { +/*! + * \brief Matrix multiplication. + * \tparam LhsExp type of lhs expression + * \tparam LhsExp type of rhs expression + * \tparam DType the type of elements + */ +template +struct ImplicitGEMMExp: + public Exp, + DType, type::kChainer> { + /*! \brief lhs operand */ + const LhsExp &lhs_; + /*! \brief rhs operand */ + const RhsExp &rhs_; + /*! \brief internal production size*/ + index_t prod_size_; + /*! \brief the shape of this expression */ + Shape<2> shape_; + /*! \brief constructor */ + ImplicitGEMMExp(const LhsExp &lhs, const RhsExp &rhs) + : lhs_(lhs), rhs_(rhs) { + Shape<2> slhs = ShapeCheck<2, LhsExp>::Check(lhs_); + Shape<2> srhs = ShapeCheck<2, RhsExp>::Check(rhs_); + this->shape_ = mshadow::Shape2(slhs[0], srhs[1]); + prod_size_ = slhs[1]; + } +}; + + +template +inline ImplicitGEMMExp +implicit_dot(const Exp &lhs, + const Exp &rhs) { + TypeCheckPass::kDim == 2 && ExpInfo::kDim == 2> + ::Error_Expression_Does_Not_Meet_Dimension_Req(); + return ImplicitGEMMExp(lhs.self(), rhs.self()); +} + +//---------------------- +// Execution plan +//---------------------- +template +struct Plan, DType> { + public: + explicit Plan(const ImplicitGEMMExp &e) + : lhs_(MakePlan(e.lhs_)), + rhs_(MakePlan(e.rhs_)), + prod_size_(e.prod_size_), + prod_size_lower_align_(packet::LowerAlign(e.prod_size_)) { + } + + MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { + typedef packet::Packet Packet; + Packet sum = Packet::Fill(0); + + DType lhs_temp[Packet::kSize], rhs_temp[Packet::kSize]; + + for (index_t i = 0; i < prod_size_lower_align_; i += packet::Packet::kSize) { + // unroll + for (index_t j = 0; j < Packet::kSize; ++j) { + lhs_temp[j] = lhs_.Eval(y, i + j); + } + for (index_t j = 0; j < Packet::kSize; ++j) { + rhs_temp[j] = rhs_.Eval(i + j, x); + } + sum = sum + Packet::Load(lhs_temp) * Packet::Load(rhs_temp); + } + DType ret_result = sum.Sum(); + + for (index_t i = prod_size_lower_align_; i < prod_size_; ++i) { + ret_result += lhs_.Eval(y, i) * rhs_.Eval(i, x); + } + return ret_result; + } + + private: + expr::Plan lhs_; + expr::Plan rhs_; + const index_t prod_size_; + const index_t prod_size_lower_align_; +}; + +template +inline Plan, DType> +MakePlan(const ImplicitGEMMExp &exp) { + return Plan, DType>(exp); +} + + +template +struct ShapeCheck > { + inline static Shape + Check(const ImplicitGEMMExp &t) { + CHECK(dim == 2) + << "ImplicitGEMMExp only support 2 dimension"; + Shape shape1 = ShapeCheck::Check(t.lhs_); + Shape shape2 = ShapeCheck::Check(t.rhs_); + CHECK_EQ(shape1[1], shape2[0]) + << "implicit_dot The matrix shape do not match"; + return t.shape_; + } +}; + +template +struct ExpInfo > { + static const int kDim = 2; + static const int kDevMask = ExpInfo::kDevMask & ExpInfo::kDevMask; +}; + +} // namespace expr +} // namespace mshadow +#endif // MSHADOW_EXTENSION_IMPLICIT_GEMM_H_ + From aa77e19d6e17b681034a068e3af988496e553bf9 Mon Sep 17 00:00:00 2001 From: tqchen Date: Sat, 17 Oct 2015 16:07:29 -0700 Subject: [PATCH 2/2] change to unaligned load --- mshadow/extension/implicit_gemm.h | 2 +- mshadow/packet/plain-inl.h | 4 ++++ mshadow/packet/sse-inl.h | 7 +++++++ 3 files changed, 12 insertions(+), 1 deletion(-) diff --git a/mshadow/extension/implicit_gemm.h b/mshadow/extension/implicit_gemm.h index 8b32d93b2c68..64a7b3e47c04 100644 --- a/mshadow/extension/implicit_gemm.h +++ b/mshadow/extension/implicit_gemm.h @@ -77,7 +77,7 @@ struct Plan, DType> { for (index_t j = 0; j < Packet::kSize; ++j) { rhs_temp[j] = rhs_.Eval(i + j, x); } - sum = sum + Packet::Load(lhs_temp) * Packet::Load(rhs_temp); + sum = sum + Packet::LoadUnAligned(lhs_temp) * Packet::LoadUnAligned(rhs_temp); } DType ret_result = sum.Sum(); diff --git a/mshadow/packet/plain-inl.h b/mshadow/packet/plain-inl.h index b28671f59d05..ab2453f9c54d 100644 --- a/mshadow/packet/plain-inl.h +++ b/mshadow/packet/plain-inl.h @@ -30,6 +30,10 @@ struct Packet { MSHADOW_CINLINE static Packet Load(const DType* src) { return Packet(*src); } + // load from address + MSHADOW_CINLINE static Packet LoadUnAligned(const DType* src) { + return Packet(*src); + } // fill it with value s MSHADOW_CINLINE Packet& operator=(DType s) { data_ = s; diff --git a/mshadow/packet/sse-inl.h b/mshadow/packet/sse-inl.h index cdf24c5e6edd..295fb289127c 100644 --- a/mshadow/packet/sse-inl.h +++ b/mshadow/packet/sse-inl.h @@ -32,6 +32,10 @@ struct Packet { MSHADOW_CINLINE static Packet Load(const float* src) { return Packet(_mm_load_ps(src)); } + // load from address + MSHADOW_CINLINE static Packet LoadUnAligned(const float* src) { + return Packet(_mm_loadu_ps(src)); + } // fill it with value s MSHADOW_CINLINE Packet& operator=(float s) { data_ = _mm_set1_ps(s); @@ -73,6 +77,9 @@ struct Packet { MSHADOW_CINLINE static Packet Load(const double* src) { return Packet(_mm_load_pd(src)); } + MSHADOW_CINLINE static Packet LoadUnAligned(const double* src) { + return Packet(_mm_loadu_pd(src)); + } // fill it with value s MSHADOW_CINLINE Packet& operator=(double s) { data_ = _mm_set1_pd(s);