Skip to content

Commit

Permalink
[Doc][Polish] gemm optimize by 2d thread tile (#56)
Browse files Browse the repository at this point in the history
Co-authored-by: liu jiawei <[email protected]>
  • Loading branch information
muyuuuu and liu jiawei authored Dec 3, 2024
1 parent 2d37c86 commit a6b8d59
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 20 deletions.
38 changes: 22 additions & 16 deletions docs/11_gemm_optimize/01_tiled2d/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,21 @@

在介绍二维 Thread Tile 之前,我们先来回顾一下一维 Thread Tile 的优化方法。在初级系列中,我们使用了一维线程块来优化矩阵乘法的性能,我们将矩阵乘法的计算任务分配给了一维线程块,每个线程块负责计算一个小的矩阵块。这样做的好处是可以充分利用共享内存,减少全局内存的访问次数,从而提高矩阵乘法的性能。

我们在每个线程中计算了一维的矩阵块。想要继续优化这个 Kernel 的性能,我们可以使用二维线程块来计算二维的矩阵块。
还记得一维 Thread Tile 中的例子吗?如果输入的 A 和 B 都是 7x7 的矩阵:

1. 如果我们一次读取 1 行 A 和 1 列 B,当每一个线程只计算一个结果的时候,我们需要从 A 中读取 7 个数据,从 B 中读取 7 个数据,从 C 中读取 1 个数据,然后写 1 次 C。这样的话,每个线程需要读取 15 个数据,写 1 次数据。计算 16 个结果需要 16 个线程,共 16x16 = 256 次 IO。
2. 如果我们一次读取 4 行 A 和 1 列 B,那么每一个线程计算 4 个结果,此时需要从 A 中读取 4x7 个数据,从 B 中读取 7 个数据,从 C 中读取 4 个数据,然后写 4 次 C。计算 16 个结果需要 4 个线程,共 4x43 = 172 次 IO。
3. 如果我们一次读取 4 行 A 和 4 列 B,那么每一个线程计算 16 个结果,此时需要从 A 中读取 4x7 个数据,从 B 中读取 4x7 个数据,从 C 中读取 16 个数据,然后写 16 次 C。计算 16 个结果一共需要 1 个线程,共 1x88 = 88 次 IO。

上述的 `2` 就是一维 Thread Tile 优化,上述的 `3` 就是 二维 Thread Tile 优化,计算结果不变的同时,减少 IO 次数,提升算法的执行时间。所以想要继续优化这个 Kernel 的性能,我们可以使用二维线程块来计算二维的矩阵块。

## 2. 二维 Thread Tile

### 2.1 优化思路

本文的主要优化思路就是让每个线程计算一个 8\*8 的网格。下面我们来看一下这个 Kernel 的主题思路图:

![picture 1](images/9047246849f79b5117961c15e1a3a340a44ab003566140ecc00600058c70a9e2.png)
![picture 1](images/9047246849f79b5117961c15e1a3a340a44ab003566140ecc00600058c70a9e2.png)

首先在内核的第一阶段, 所有线程协同工作, 从全局内存中加载矩阵 A 和矩阵 B 到共享内存中。

Expand Down Expand Up @@ -74,9 +80,9 @@ float thread_results[TM * TN] = {0.0};
float reg_m[TM] = {0.0};
float reg_n[TN] = {0.0};
A += c_row * BM * K;
B += c_col * BN;
C += c_row * BM * N + c_col * BN;
A += c_row * BM * K;
B += c_col * BN;
C += c_row * BM * N + c_col * BN;
// 外层循环
for (uint bkIdx = 0; bkIdx < K; bkIdx += BK)
Expand Down Expand Up @@ -107,7 +113,7 @@ B += BK * N;

下图可以更好的帮助我们理解上面的代码:

![picture 2](images/f507ad687528e8bbb14a85c1fa3016cce50be55b5670ebc425c549cc5c5bd5a6.png)
![picture 2](images/f507ad687528e8bbb14a85c1fa3016cce50be55b5670ebc425c549cc5c5bd5a6.png)

图中画出了矩阵 A 加载共享内存的过程。在每一步中,每个线程负责加载一个元素到共享内存中。这个元素的索引是 `inner_row_A``inner_col_A` 。for 循环中的 `load_offset` 递增的步长是 `stride_A` 。在图中就是向下移动了 `stride_A` 个元素。

Expand All @@ -130,7 +136,7 @@ for (uint dot_idx = 0; dot_idx < BK; ++dot_idx)
{
for (uint reg_idx_n = 0; reg_idx_n < TN; ++reg_idx_n)
{
thread_results[reg_idx_m * TN + reg_idx_n] +=
thread_results[reg_idx_m * TN + reg_idx_n] +=
reg_m[reg_idx_m] * reg_n[reg_idx_n];
}
}
Expand Down Expand Up @@ -158,7 +164,7 @@ for (uint reg_idx_m = 0; reg_idx_m < TM; ++reg_idx_m)
{
for (uint reg_idx_n = 0; reg_idx_n < TN; ++reg_idx_n)
{
C[(thread_row * TM + reg_idx_m) * N + thread_col * TN + reg_idx_n] =
C[(thread_row * TM + reg_idx_m) * N + thread_col * TN + reg_idx_n] =
thread_results[reg_idx_m * TN + reg_idx_n];
}
}
Expand All @@ -174,20 +180,20 @@ nvcc -o sgemm_tiled2d sgemm_tiled2d.cu
## 3. 性能测试

我们将上该内核的性能和之前的内核进行比较,我们分别计算 256x256、512x512、1024x1024、2048x2048 (Matrix 1、Matrix 2、Matrix 3、Matrix 4、Matrix 5)的矩阵乘法的性能 (us)。在 1080Ti 上运行,结果如下:


| Algorithm | Matrix 1 | Matrix 2 | Matrix 3 | Matrix 4 |
| --------- | -------- | -------- | -------- | -------- |
| Naive | 95.5152 | 724.396 | 28424 | 228681 |
| 共享内存缓存块 | 40.5293 | 198.374 | 8245.68 | 59048.4 |
| 一维 Thread Tile | 35.215 | 174.731 | 894.779 | 5880.03 |
| 二维 Thread Tile | 34.708 | 92.946 | 557.829 | 3509.920 |

| Algorithm | Matrix 1 | Matrix 2 | Matrix 3 | Matrix 4 |
| ---------------- | -------- | -------- | -------- | -------- |
| Naive | 95.5152 | 724.396 | 28424 | 228681 |
| 共享内存缓存块 | 40.5293 | 198.374 | 8245.68 | 59048.4 |
| 一维 Thread Tile | 35.215 | 174.731 | 894.779 | 5880.03 |
| 二维 Thread Tile | 34.708 | 92.946 | 557.829 | 3509.920 |

## 4. 总结

本文我们介绍了二维 Thread Tile 并行优化方法。我们将矩阵乘法的计算任务分配给了二维线程块,每个线程块负责计算一个小的矩阵块。这样做的好处是可以充分利用共享内存,减少全局内存的访问次数,从而提高矩阵乘法的性能。

## Reference
## Reference

1. https://siboehm.com/articles/22/CUDA-MMM
2. https://space.keter.top/docs/high_performance/GEMM%E4%BC%98%E5%8C%96%E4%B8%93%E9%A2%98/%E4%BA%8C%E7%BB%B4Thread%20Tile%E5%B9%B6%E8%A1%8C%E4%BC%98%E5%8C%96
Expand Down
27 changes: 23 additions & 4 deletions docs/11_gemm_optimize/01_tiled2d/sgemm_tiled2d.cu
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,22 @@
#include <cuda_runtime.h>
#include <cassert>

#define CEIL_DIV(M, N) (((M) + (N)-1) / (N))
#define CEIL_DIV(M, N) (((M) + (N) - 1) / (N))

template <typename T>
void free_resource(T *p_cpu, T *p_cuda)
{
if (nullptr != p_cpu)
{
delete[] p_cpu;
p_cpu = nullptr;
}
if (nullptr != p_cuda)
{
cudaFree(p_cuda);
p_cuda = nullptr;
}
}

void sgemm_naive_cpu(float *A, float *B, float *C, int M, int N, int K)
{
Expand Down Expand Up @@ -155,7 +170,7 @@ int main(int argc, char *argv[])

// Allocate memory for matrices
float *A, *B, *C, *C_ref;
float *d_A, *d_B, *d_C, *d_C_ref;
float *d_A, *d_B, *d_C;

A = new float[m * k];
B = new float[k * n];
Expand All @@ -176,13 +191,11 @@ int main(int argc, char *argv[])
cudaMalloc((void **)&d_A, m * k * sizeof(float));
cudaMalloc((void **)&d_B, k * n * sizeof(float));
cudaMalloc((void **)&d_C, m * n * sizeof(float));
cudaMalloc((void **)&d_C_ref, m * n * sizeof(float));

// Copy matrices to device
cudaMemcpy(d_A, A, m * k * sizeof(float), cudaMemcpyHostToDevice);
cudaMemcpy(d_B, B, k * n * sizeof(float), cudaMemcpyHostToDevice);
cudaMemcpy(d_C, C, m * n * sizeof(float), cudaMemcpyHostToDevice);
cudaMemcpy(d_C_ref, C_ref, m * n * sizeof(float), cudaMemcpyHostToDevice);

run_sgemm_blocktiling_2d(d_A, d_B, d_C, m, n, k);

Expand Down Expand Up @@ -219,5 +232,11 @@ int main(int argc, char *argv[])
cudaEventElapsedTime(&elapsed_time, start, stop);
float avg_run_time = elapsed_time * 1000 / 100;
printf("Average run time: %f us\n", avg_run_time);

free_resource(A, d_A);
free_resource(B, d_B);
free_resource(C, d_C);
free_resource(C_ref, (float *)nullptr);

return 0;
}

0 comments on commit a6b8d59

Please sign in to comment.