To clone this repo:

```bash
git clone --recursive https://github.com/jundaf2/eigenMHA
cd eigenMHA
git clone https://gitlab.com/libeigen/eigen # clone eigen if necessary
```
In this repo, we use Eigen3 to implement the forward and backward passes of Multi-Head Attention (MHA) in Transformer models. Basically, this repo has two branches -- `torch` and `cudnn` -- and provides:
- a PyTorch MHA in `mha.py` that illustrates the MHA module we implement
- an Eigen MHA in `mha.cc` in both branches (with sources in `./src/eigenDNN.cpp` and headers in `./include/eigenDNN.h`)
- a LibTorch MHA in the `torch` branch as a comparison to the eigenMHA
- a cuDNN MHA in the `cudnn` branch as a comparison to the eigenMHA
```bash
git checkout torch
```
In this branch, the eigenDNN is compared with the CPU LibTorch. To build and run the project, first install LibTorch for the necessary verification; see https://github.com/jundaf2/dnn-test-framework (nnTest mainly focuses on providing a testing framework to train and run inference with Deep Neural Networks using YOUR OWN LIBRARY). Then,
```bash
mkdir build && cd build
cmake ..
make -j4
./mha
```
```bash
git checkout cudnn
```
In this branch, the eigenDNN is compared with the Multi-Head Attention APIs provided by cuDNN v8 (`cudnn_samples_v8/multiHeadAttention`).

To install cuDNN, see https://developer.nvidia.com/rdp/cudnn-download and https://docs.nvidia.com/deeplearning/cudnn/install-guide/index.html#installlinux-tar. After copying the corresponding libraries and headers to the correct locations,
```bash
mkdir build && cd build
cmake ..
make -j4
./mha
```
To be more specific, eigenDNN covers what cuDNN does through the following APIs for MHA operations.
- `cudnnCreateAttnDescriptor()`
- `cudnnSetAttnDescriptor()`
- `cudnnGetAttnDescriptor()`
- `cudnnDestroyAttnDescriptor()`
- `cudnnGetMultiHeadAttnBuffers()`
- `cudnnGetMultiHeadAttnWeights()`
- `cudnnMultiHeadAttnForward()`
- `cudnnMultiHeadAttnBackwardData()`
- `cudnnMultiHeadAttnBackwardWeights()`
For more details on the Attention APIs in cuDNN v8, see this CSDN article (in Chinese).
The forward pass of MHA involves the following variables.

- Q, K, V input embeddings
$$ \mathbf{Q}_{in} \quad \mathbf{K}_{in} \quad \mathbf{V}_{in} $$
- Weights and biases for the linear layers of Q, K, V and O
$$ \mathbf{W}_{Q} \quad \mathbf{b}_{Q} $$
$$ \mathbf{W}_{K} \quad \mathbf{b}_{K} $$
$$ \mathbf{W}_{V} \quad \mathbf{b}_{V} $$
$$ \mathbf{W}_{O} \quad \mathbf{b}_{O} $$
- Intermediate variables
$$ \mathbf{Q} \quad \mathbf{K} \quad \mathbf{V} \quad \mathbf{S} \quad \mathbf{P} \quad \mathbf{O} $$
- Output and target
$$ \mathbf{O}_{out} \quad \mathbf{O}_{target} $$
The equations of the MHA forward pass are as follows. The input projections are

$$ \mathbf{Q} = \mathbf{Q}_{in}*\mathbf{W}_{Q}+\mathbf{b}_{Q} $$
$$ \mathbf{K} = \mathbf{K}_{in}*\mathbf{W}_{K}+\mathbf{b}_{K} $$
$$ \mathbf{V} = \mathbf{V}_{in}*\mathbf{W}_{V}+\mathbf{b}_{V} $$

the attention core computes (where $d$ is the per-head size)

$$ \mathbf{S} = \mathbf{Q}*\mathbf{K}^T/\sqrt{d} $$
$$ \mathbf{P} = dropout(softmax(\mathbf{S})) $$
$$ \mathbf{O} = \mathbf{P}*\mathbf{V} $$

and the output projection and loss are

$$ \mathbf{O}_{out} = \mathbf{O}*\mathbf{W}_{O}+\mathbf{b}_{O} $$
$$ loss = MSELoss(\mathbf{O}_{out},\mathbf{O}_{target}) $$
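Below is a minimal single-head sanity check of these equations written with plain Eigen matrices -- a sketch for intuition only, independent of the eigenDNN API; biases and dropout are omitted and all sizes are made up.

```cpp
#include <Eigen/Dense>
#include <cmath>
#include <iostream>

int main() {
  const int seq_len = 4, d_model = 8;                 // illustrative sizes
  const float scale = 1.0f / std::sqrt((float)d_model);

  Eigen::MatrixXf Q_in = Eigen::MatrixXf::Random(seq_len, d_model);
  Eigen::MatrixXf K_in = Eigen::MatrixXf::Random(seq_len, d_model);
  Eigen::MatrixXf V_in = Eigen::MatrixXf::Random(seq_len, d_model);
  Eigen::MatrixXf W_Q  = Eigen::MatrixXf::Random(d_model, d_model);
  Eigen::MatrixXf W_K  = Eigen::MatrixXf::Random(d_model, d_model);
  Eigen::MatrixXf W_V  = Eigen::MatrixXf::Random(d_model, d_model);
  Eigen::MatrixXf W_O  = Eigen::MatrixXf::Random(d_model, d_model);

  // Input projections (biases omitted to keep the sketch short)
  Eigen::MatrixXf Q = Q_in * W_Q, K = K_in * W_K, V = V_in * W_V;

  // Attention core: S = Q*K^T/sqrt(d), P = row-wise softmax(S)
  Eigen::MatrixXf S = Q * K.transpose() * scale;
  Eigen::VectorXf rowmax = S.rowwise().maxCoeff();
  Eigen::MatrixXf P = (S.colwise() - rowmax).array().exp().matrix();
  Eigen::VectorXf rowsum = P.rowwise().sum();
  P = rowsum.cwiseInverse().asDiagonal() * P;        // normalize each row
  Eigen::MatrixXf O = P * V;

  // Output projection and MSE loss against a random target
  Eigen::MatrixXf O_out    = O * W_O;
  Eigen::MatrixXf O_target = Eigen::MatrixXf::Random(seq_len, d_model);
  float loss = (O_out - O_target).squaredNorm() / O_out.size();
  std::cout << "loss = " << loss << std::endl;
  return 0;
}
```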
MSELoss also gives
$$ \mathbf{grad\_O}_{out} $$
the gradient of the loss with respect to the MHA output, which seeds the backward pass. The backward pass involves the following variables.
- Gradient of the output (in a full Transformer this would come from the following LayerNorm; here it comes from MSELoss)
$$ \mathbf{grad\_O}_{out} $$
- Gradients of the intermediate variables
$$ \mathbf{grad\_Q} \quad \mathbf{grad\_K} \quad \mathbf{grad\_V} \quad \mathbf{grad\_S} \quad \mathbf{grad\_P} \quad \mathbf{grad\_O} $$
- Gradients of the forward inputs
$$ \mathbf{grad\_Q}_{in} \quad \mathbf{grad\_K}_{in} \quad \mathbf{grad\_V}_{in} $$
- Gradients of the weights and biases
$$ \mathbf{grad\_W}_{Q} \quad \mathbf{grad\_b}_{Q} $$
$$ \mathbf{grad\_W}_{K} \quad \mathbf{grad\_b}_{K} $$
$$ \mathbf{grad\_W}_{V} \quad \mathbf{grad\_b}_{V} $$
$$ \mathbf{grad\_W}_{O} \quad \mathbf{grad\_b}_{O} $$
The equations of the MHA backward pass are as follows. The output projection gives

$$ \mathbf{grad\_O} = \mathbf{grad\_O}_{out}*\mathbf{W}_{O}^T $$
$$ \mathbf{grad\_W}_{O} = \mathbf{O}^T*\mathbf{grad\_O}_{out} $$
$$ \mathbf{grad\_b}_{O} = colsum(\mathbf{grad\_O}_{out}) $$

Back-propagating this gradient through the attention core (the second batched GEMM, dropout, softmax, and the first batched GEMM, in that order) yields the gradients of Q, K and V, and the input projections then give

$$ \mathbf{grad\_Q}_{in} = \mathbf{grad\_Q}*\mathbf{W}_{Q}^T $$
$$ \mathbf{grad\_W}_{Q} = \mathbf{Q}_{in}^T*\mathbf{grad\_Q} $$
$$ \mathbf{grad\_b}_{Q} = colsum(\mathbf{grad\_Q}) $$
$$ \mathbf{grad\_K}_{in} = \mathbf{grad\_K}*\mathbf{W}_{K}^T $$
$$ \mathbf{grad\_W}_{K} = \mathbf{K}_{in}^T*\mathbf{grad\_K} $$
$$ \mathbf{grad\_b}_{K} = colsum(\mathbf{grad\_K}) $$
$$ \mathbf{grad\_V}_{in} = \mathbf{grad\_V}*\mathbf{W}_{V}^T $$
$$ \mathbf{grad\_W}_{V} = \mathbf{V}_{in}^T*\mathbf{grad\_V} $$
$$ \mathbf{grad\_b}_{V} = colsum(\mathbf{grad\_V}) $$
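Continuing the single-head Eigen sketch from the forward pass above, the backward pass can be unrolled explicitly, including the attention-core steps the prose summarizes. Again an illustrative sketch with dropout omitted, not the repo's actual implementation.

```cpp
// Continues the forward sketch above (same variables in scope).
const float N = (float)O_out.size();
Eigen::MatrixXf grad_O_out = 2.0f / N * (O_out - O_target); // MSE gradient

// Output projection
Eigen::MatrixXf    grad_O   = grad_O_out * W_O.transpose();
Eigen::MatrixXf    grad_W_O = O.transpose() * grad_O_out;
Eigen::RowVectorXf grad_b_O = grad_O_out.colwise().sum();   // colsum

// Attention core (dropout omitted, matching the forward sketch)
Eigen::MatrixXf grad_V = P.transpose() * grad_O;
Eigen::MatrixXf grad_P = grad_O * V.transpose();
// Row-wise softmax backward: grad_S = (grad_P - rowsum(grad_P .* P)) .* P
Eigen::VectorXf rowdot = (grad_P.array() * P.array()).rowwise().sum().matrix();
Eigen::MatrixXf grad_S = ((grad_P.colwise() - rowdot).array() * P.array()).matrix();
Eigen::MatrixXf grad_Q = grad_S * K * scale;
Eigen::MatrixXf grad_K = grad_S.transpose() * Q * scale;

// Input projections (Q shown; K and V follow the same pattern)
Eigen::MatrixXf    grad_Q_in = grad_Q * W_Q.transpose();
Eigen::MatrixXf    grad_W_Q  = Q_in.transpose() * grad_Q;
Eigen::RowVectorXf grad_b_Q  = grad_Q.colwise().sum();      // colsum
```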
The loss function is where back-propagation starts, and is thus a basic component of a deep learning system. For the MSE loss, eigenDNN provides:

```cpp
eidnnStatus_t eidnnMSELoss(
    eidnnHandle_t handle,
    const Tensor<float, 3> &output,
    const Tensor<float, 3> &target,
    Tensor<float, 0> &loss,
    Tensor<float, 3> &d_loss);
```
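Concretely, assuming the usual mean-reduced MSE (an assumption about the reduction; see `./src/eigenDNN.cpp` for the exact convention), the two outputs correspond to

$$ loss = \frac{1}{N}\sum_{i=1}^{N}(output_i - target_i)^2 \qquad d\_loss_i = \frac{2}{N}(output_i - target_i) $$

where $N$ is the total number of elements.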
cuDNN has no specific APIs for the linear layer. In eigenDNN, we have
```cpp
eidnnStatus_t eidnnLinearForward(eidnnHandle_t handle,
    const Tensor<float, 3>& x,    // input data
    const Tensor<float, 2>& w,    // weight
    const Tensor<float, 1>& bias, // bias
    Tensor<float, 3>& y);         // output data

eidnnStatus_t eidnnLinearBackward(eidnnHandle_t handle,
    const Tensor<float, 3>& dy,   // gradient of output data
    const Tensor<float, 3>& x,    // input data (saved from forward)
    const Tensor<float, 2>& w,    // weight
    Tensor<float, 3>& dx,         // gradient of input data
    Tensor<float, 2>& dw,         // accumulated gradient of weight
    Tensor<float, 1>& dbias);     // accumulated gradient of bias
```
where `dx` receives the gradient of the input data, and `dw` and `dbias` receive the accumulated gradients of the weight and bias, respectively.
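A hypothetical call sequence follows; the tensor shapes are assumptions for illustration only, so consult `./include/eigenDNN.h` for the actual layout conventions.

```cpp
#include <unsupported/Eigen/CXX11/Tensor>
#include "eigenDNN.h" // assumed header path, per the repo layout above

void linear_example(eidnnHandle_t handle) {
  // Assumed shapes: x [batch, seq, in], w [out, in], bias [out], y [batch, seq, out]
  Eigen::Tensor<float, 3> x(2, 4, 8);    x.setRandom();
  Eigen::Tensor<float, 2> w(16, 8);      w.setRandom();
  Eigen::Tensor<float, 1> bias(16);      bias.setZero();
  Eigen::Tensor<float, 3> y(2, 4, 16);
  eidnnLinearForward(handle, x, w, bias, y);

  // Backward: dw and dbias are accumulated, so zero them before the first call.
  Eigen::Tensor<float, 3> dy(2, 4, 16);  dy.setRandom();
  Eigen::Tensor<float, 3> dx(2, 4, 8);
  Eigen::Tensor<float, 2> dw(16, 8);     dw.setZero();
  Eigen::Tensor<float, 1> dbias(16);     dbias.setZero();
  eidnnLinearBackward(handle, dy, x, w, dx, dw, dbias);
}
```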
cuDNN has no specific APIs for the batched matrix-multiply operation. In eigenDNN, we have
```cpp
eidnnStatus_t eidnnStridedBatchedGemmForward(
    eidnnHandle_t handle,
    float alpha,
    float beta,
    bool trans_A, // Op_a
    bool trans_B, // Op_b
    bool trans_C, // Op_c
    const Tensor<float, 4> &A,
    const Tensor<float, 4> &B,
    Tensor<float, 4> &C);

eidnnStatus_t eidnnStridedBatchedGemmBackward(
    eidnnHandle_t handle,
    float alpha,
    float beta,
    bool trans_A, // Op_a
    bool trans_B, // Op_b
    bool trans_C, // Op_c
    const Tensor<float, 4> &A,   // A
    const Tensor<float, 4> &B,   // B
    const Tensor<float, 4> &d_C, // gradient of C
    Tensor<float, 4> &d_A,       // gradient of A
    Tensor<float, 4> &d_B);      // gradient of B
```
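Assuming cuBLAS-like semantics, i.e. C = alpha * op(A) * op(B) + beta * C batched over the two leading dimensions (an assumption, as is the tensor layout; the `trans_C` flag suggests C may also be stored transposed), the attention-score GEMM might look like:

```cpp
#include <cmath>
#include <unsupported/Eigen/CXX11/Tensor>
#include "eigenDNN.h" // assumed header path

void gemm_example(eidnnHandle_t handle) {
  // Assumed layout: [batch, heads, seq_len, head_dim]
  Eigen::Tensor<float, 4> Q(2, 8, 4, 64), K(2, 8, 4, 64), S(2, 8, 4, 4);
  Q.setRandom(); K.setRandom(); S.setZero();
  // S = (1/sqrt(d)) * Q * K^T, with the 1/sqrt(d) scaling folded into alpha
  // and beta = 0 so that C is overwritten.
  eidnnStridedBatchedGemmForward(handle, 1.0f / std::sqrt(64.0f), 0.0f,
                                 /*trans_A=*/false, /*trans_B=*/true,
                                 /*trans_C=*/false, Q, K, S);
}
```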
cuDNN has the following APIs for the softmax operation.

- `cudnnSoftmaxForward()`
- `cudnnSoftmaxBackward()`

In eigenDNN, we have
```cpp
eidnnStatus_t eidnnSoftmaxForward(eidnnHandle_t handle,
    eidnnSoftmaxAlgorithm_t algo,
    eidnnSoftmaxMode_t mode,
    const Tensor<float, 4>& x,
    Tensor<float, 4>& y);

eidnnStatus_t eidnnSoftmaxBackward(eidnnHandle_t handle,
    eidnnSoftmaxAlgorithm_t algo,
    eidnnSoftmaxMode_t mode,
    const Tensor<float, 4>& y,  // softmax output (saved from forward)
    const Tensor<float, 4>& dy, // gradient of output
    Tensor<float, 4>& dx);      // gradient of input
```
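Note that the backward call takes the forward output `y` rather than the input `x`: by the standard softmax Jacobian identity, the input gradient needs only `y` and `dy`, applied row-wise along the softmax dimension:

$$ dx_i = y_i\left(dy_i - \sum_j y_j\, dy_j\right) $$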
cuDNN has the following APIs for the dropout operation.

- `cudnnCreateDropoutDescriptor()`
- `cudnnDestroyDropoutDescriptor()`
- `cudnnDropoutGetStatesSize()`
- `cudnnDropoutGetReserveSpaceSize()`
- `cudnnDropoutForward()`
- `cudnnGetDropoutDescriptor()`
- `cudnnRestoreDropoutDescriptor()`
- `cudnnSetDropoutDescriptor()`
- `cudnnDropoutBackward()`
In eigenDNN, we have

```cpp
// dropout rate,
// pointer to memory space of states (allocated by forward pass),
// size of memory space in bytes (calculated by forward pass),
// random seed
using eidnnDropoutDescriptor_t = std::tuple<float, void*, size_t, unsigned long long>;

eidnnStatus_t eidnnDropoutForward(
    eidnnHandle_t handle,
    eidnnDropoutDescriptor_t &dropoutDesc,
    const Tensor<float, 4> &x, // input data
    Tensor<float, 4> &y);      // input data after dropout

eidnnStatus_t eidnnDropoutBackward(
    eidnnHandle_t handle,
    const eidnnDropoutDescriptor_t dropoutDesc,
    const Tensor<float, 4> &dy, // gradient of dropout output data
    Tensor<float, 4> &dx);      // gradient of dropout input data
```
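A hypothetical usage sketch follows; per the tuple comments above, the forward pass fills in the state pointer and size so the backward pass can replay the same mask, and the seed value here is arbitrary.

```cpp
#include <unsupported/Eigen/CXX11/Tensor>
#include "eigenDNN.h" // assumed header path

void dropout_example(eidnnHandle_t handle) {
  // {rate, state pointer (filled by forward), state size (filled by forward), seed}
  eidnnDropoutDescriptor_t desc = {0.1f, nullptr, 0, 2023ULL};

  Eigen::Tensor<float, 4> x(2, 8, 4, 4), y(2, 8, 4, 4);
  x.setRandom();
  eidnnDropoutForward(handle, desc, x, y);

  // Backward replays the mask recorded in desc by the forward pass.
  Eigen::Tensor<float, 4> dy(2, 8, 4, 4), dx(2, 8, 4, 4);
  dy.setRandom();
  eidnnDropoutBackward(handle, desc, dy, dx);
}
```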