step2: flash-attn forward algorithm with tiling (double-loop):
the outer loop runs through $i := 1 \rightarrow N_q$ for each block of $Q_i$ to compute $O_i$, where $N_q = \lceil\frac{N}{B_q}\rceil$
$$
\text{in the i-th outer iteration}:
\begin{cases}
\begin{align}
&\text{load}\space Q_i \in \mathbb{R}^{B_q\times d}\space \text{from HBM to SRAM}\notag\\
&\text{initialize}\space \tilde{O_{i}}^{(0)} = 0_{ B_q\times d },\space l_i^{(0)} = 0_{B_q} \in \mathbb{R}^{B_q},\space m_i^{(0)} = -\infty_{B_q} \in \mathbb{R}^{B_q} \notag\\
&\text{run the inner loop over}\space j := 1 \rightarrow N_k\space \text{to update}\space \tilde{O_i}^{(j)}, l_i^{(j)}, m_i^{(j)} \notag\\
&\text{compute}\space O_i = \mathrm{diag}(l_{i}^{(N_k)})^{-1} \tilde{O_i}^{(N_k)}\in \mathbb{R}^{B_q\times d}\space \text{and write it to HBM to return as output} \notag\\
&\text{compute}\space \mathrm{LSE}_i = m_i^{(N_k)} + \log(l_i^{(N_k)})\in \mathbb{R}^{B_q} \space \text{and write it to HBM to save for backward} \notag
\end{align}
\end{cases}
$$

$$
\begin{align}
&\text{where}\quad \text{LSE}( \mathbf{x}) := \log\left(\sum\limits_{i=1}^n \exp(x_i)\right) = \max( \mathbf x) + \text{LSE}( \mathbf{x}-\max( \mathbf x)),\space \mathbf x \in \mathbb{R}^{n},\\
&\text{and}\space \tilde{O_i} \space\text{is the un-normalized} \space O_i, \space\text{i.e.}\space O_i = \mathrm{diag}(l_{i})^{-1}\tilde{O_i}
\end{align}
$$
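as a quick numeric sanity check of the max-shifted LSE identity above (a minimal NumPy sketch, not part of the algorithm itself):

```python
import numpy as np

x = np.array([1000.0, 1001.0, 1002.0])

# naive evaluation: exp(1000) overflows float64, so the sum becomes inf
with np.errstate(over="ignore"):
    naive = np.log(np.sum(np.exp(x)))  # -> inf

# stable evaluation via the identity: subtract max(x) before exponentiating
m = x.max()
stable = m + np.log(np.sum(np.exp(x - m)))  # -> ~1002.4076
```

the shift leaves the result mathematically unchanged but keeps every `exp` argument $\le 0$, which is exactly why the forward pass tracks the running max $m_i^{(j)}$ alongside the running sum $l_i^{(j)}$.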
in which the inner loop iterates over $j := 1 \rightarrow N_k$, using each block of $K_j, V_j$ to update $\tilde{O_i}^{(j)}, l_i^{(j)}, m_i^{(j)}$, where $N_k = \lceil\frac{N}{B_k}\rceil$
this way we never materialize the full attention score matrix $S = QK^{\mathrm T} \in \mathbb{R}^{N\times N}$ in HBM, since each inner iteration only handles one tile $S_i^{(j)} = Q_iK_j^{\mathrm T} \in \mathbb{R}^{B_q\times B_k}$ in SRAM
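the forward double-loop can be sketched in NumPy as follows (a minimal illustrative sketch, not the fused CUDA kernel; the block sizes and the $1/\sqrt{d}$ scaling of $Q_iK_j^{\mathrm T}$ are assumed here, and the function name is hypothetical):

```python
import numpy as np

def flash_attn_forward(Q, K, V, B_q=16, B_k=16):
    N, d = Q.shape
    scale = 1.0 / np.sqrt(d)
    O = np.empty_like(Q)
    LSE = np.empty(N)
    for i0 in range(0, N, B_q):               # outer loop over Q blocks
        Qi = Q[i0:i0 + B_q]
        Bq = Qi.shape[0]
        O_t = np.zeros((Bq, d))               # un-normalized output ~O_i
        l = np.zeros(Bq)                      # running softmax denominator l_i
        m = np.full(Bq, -np.inf)              # running row max m_i
        for j0 in range(0, N, B_k):           # inner loop over K/V blocks
            Kj, Vj = K[j0:j0 + B_k], V[j0:j0 + B_k]
            S = Qi @ Kj.T * scale             # one B_q x B_k score tile
            m_new = np.maximum(m, S.max(axis=1))
            P = np.exp(S - m_new[:, None])    # tile probabilities, rescaled
            alpha = np.exp(m - m_new)         # correction factor for old max
            l = alpha * l + P.sum(axis=1)
            O_t = alpha[:, None] * O_t + P @ Vj
            m = m_new
        O[i0:i0 + B_q] = O_t / l[:, None]     # O_i = diag(l)^{-1} ~O_i
        LSE[i0:i0 + B_q] = m + np.log(l)      # saved for the backward pass
    return O, LSE
```

note how each tile `S` only ever has shape `(B_q, B_k)`: the rescaling factor `alpha` lets the running statistics absorb every tile without ever holding the full $N\times N$ matrix.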
step3: flash-attn backward algorithm with recomputation (double-loop):
the outer loop runs through $j := 1 \rightarrow N_k$ for each block of $K_j, V_j$ to compute $dK_j, dV_j$, where $N_k = \lceil\frac{N}{B_k}\rceil$
$$
\text{in the j-th outer iteration}:
\begin{cases}
\begin{align}
&\text{load}\space K_j, V_j \in \mathbb{R}^{B_k\times d}\space \text{from HBM to SRAM, and initialize}\space dK_j^{(0)}, dV_j^{(0)} = 0_{B_k\times d} \in \mathbb{R}^{B_k\times d} \notag \\
&\text{run the inner loop over}\space i := 1 \rightarrow N_q\space \text{to update}\space dQ_i, dK_j^{(i)}, dV_j^{(i)} \notag \\
&\text{write}\space dK_j = dK_j^{(N_q)}, dV_j = dV_j^{(N_q)} \space \text{back to HBM to return as output} \notag
\end{align}
\end{cases}
$$
in which the inner loop iterates over $i := 1 \rightarrow N_q$, using each block of $Q_i, O_i, dO_i$ to update $dQ_i, dK_j^{(i)}, dV_j^{(i)}$, where $N_q = \lceil\frac{N}{B_q}\rceil$
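the backward double-loop can likewise be sketched in NumPy (again an illustrative sketch under assumed block sizes and $1/\sqrt{d}$ scaling; the per-row term $D_i = \mathrm{rowsum}(dO_i \circ O_i)$ and the recovery of $P$ from the saved $\mathrm{LSE}_i$ follow the standard FlashAttention derivation rather than appearing explicitly in this document):

```python
import numpy as np

def flash_attn_backward(Q, K, V, O, dO, LSE, B_q=16, B_k=16):
    N, d = Q.shape
    scale = 1.0 / np.sqrt(d)
    dQ = np.zeros_like(Q)
    dK = np.zeros_like(K)
    dV = np.zeros_like(V)
    D = np.sum(dO * O, axis=1)                    # D_i = rowsum(dO_i * O_i)
    for j0 in range(0, N, B_k):                   # outer loop over K/V blocks
        Kj, Vj = K[j0:j0 + B_k], V[j0:j0 + B_k]
        dKj = np.zeros_like(Kj)
        dVj = np.zeros_like(Vj)
        for i0 in range(0, N, B_q):               # inner loop over Q blocks
            Qi, dOi = Q[i0:i0 + B_q], dO[i0:i0 + B_q]
            S = Qi @ Kj.T * scale                 # recompute the score tile
            P = np.exp(S - LSE[i0:i0 + B_q, None])  # softmax tile from LSE
            dVj += P.T @ dOi
            dP = dOi @ Vj.T
            dS = P * (dP - D[i0:i0 + B_q, None])
            dQ[i0:i0 + B_q] += dS @ Kj * scale    # dQ_i accumulated in HBM
            dKj += dS.T @ Qi * scale
        dK[j0:j0 + B_k] = dKj                     # write dK_j, dV_j back
        dV[j0:j0 + B_k] = dVj
    return dQ, dK, dV
```

recomputing `S` and `P` from the saved $\mathrm{LSE}_i$ is the "recomputation" trick: the backward pass trades a few extra FLOPs in SRAM for never having to store the $N\times N$ attention matrix in HBM during the forward pass.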