From 66b028db0b8eeb1b14c37d95ffb8f5a4244b4da0 Mon Sep 17 00:00:00 2001 From: liu jiawei Date: Tue, 3 Dec 2024 10:42:54 +0800 Subject: [PATCH] [Doc][Polish] gemm optimize by 2d thread tile, modify doc by review --- docs/11_gemm_optimize/01_tiled2d/README.md | 38 +++++++++++-------- .../01_tiled2d/sgemm_tiled2d.cu | 27 +++++++++++-- 2 files changed, 45 insertions(+), 20 deletions(-) diff --git a/docs/11_gemm_optimize/01_tiled2d/README.md b/docs/11_gemm_optimize/01_tiled2d/README.md index 6ac6d68..1632e08 100644 --- a/docs/11_gemm_optimize/01_tiled2d/README.md +++ b/docs/11_gemm_optimize/01_tiled2d/README.md @@ -6,7 +6,13 @@ 在介绍二维 Thread Tile 之前,我们先来回顾一下一维 Thread Tile 的优化方法。在初级系列中,我们使用了一维线程块来优化矩阵乘法的性能,我们将矩阵乘法的计算任务分配给了一维线程块,每个线程块负责计算一个小的矩阵块。这样做的好处是可以充分利用共享内存,减少全局内存的访问次数,从而提高矩阵乘法的性能。 -我们在每个线程中计算了一维的矩阵块。想要继续优化这个 Kernel 的性能,我们可以使用二维线程块来计算二维的矩阵块。 +还记得一维 Thread Tile 中的例子吗?如果输入的 A 和 B 都是 7x7 的矩阵: + +1. 如果我们一次读取 1 行 A 和 1 列 B,当每一个线程只计算一个结果的时候,我们需要从 A 中读取 7 个数据,从 B 中读取 7 个数据,从 C 中读取 1 个数据,然后写 1 次 C。这样的话,每个线程需要读取 15 个数据,写 1 次数据。计算 16 个结果需要 16 个线程,共 16x16 = 256 次 IO。 +2. 如果我们一次读取 4 行 A 和 1 列 B,那么每一个线程计算 4 个结果,此时需要从 A 中读取 4x7 个数据,从 B 中读取 7 个数据,从 C 中读取 4 个数据,然后写 4 次 C。计算 16 个结果需要 4 个线程,共 4x43 = 172 次 IO。 +3. 如果我们一次读取 4 行 A 和 4 列 B,那么每一个线程计算 16 个结果,此时需要从 A 中读取 4x7 个数据,从 B 中读取 4x7 个数据,从 C 中读取 16 个数据,然后写 16 次 C。计算 16 个结果一共需要 1 个线程,共 1x88 = 88 次 IO。 + +上述的 `2` 就是一维 Thread Tile 优化,上述的 `3` 就是 二维 Thread Tile 优化,计算结果不变的同时,减少 IO 次数,提升算法的执行时间。所以想要继续优化这个 Kernel 的性能,我们可以使用二维线程块来计算二维的矩阵块。 ## 2. 二维 Thread Tile @@ -14,7 +20,7 @@ 本文的主要优化思路就是让每个线程计算一个 8\*8 的网格。下面我们来看一下这个 Kernel 的主题思路图: -![picture 1](images/9047246849f79b5117961c15e1a3a340a44ab003566140ecc00600058c70a9e2.png) +![picture 1](images/9047246849f79b5117961c15e1a3a340a44ab003566140ecc00600058c70a9e2.png) 首先在内核的第一阶段, 所有线程协同工作, 从全局内存中加载矩阵 A 和矩阵 B 到共享内存中。 @@ -74,9 +80,9 @@ float thread_results[TM * TN] = {0.0}; float reg_m[TM] = {0.0}; float reg_n[TN] = {0.0}; -A += c_row * BM * K; -B += c_col * BN; -C += c_row * BM * N + c_col * BN; +A += c_row * BM * K; +B += c_col * BN; +C += c_row * BM * N + c_col * BN; // 外层循环 for (uint bkIdx = 0; bkIdx < K; bkIdx += BK) @@ -107,7 +113,7 @@ B += BK * N; 下图可以更好的帮助我们理解上面的代码: -![picture 2](images/f507ad687528e8bbb14a85c1fa3016cce50be55b5670ebc425c549cc5c5bd5a6.png) +![picture 2](images/f507ad687528e8bbb14a85c1fa3016cce50be55b5670ebc425c549cc5c5bd5a6.png) 图中画出了矩阵 A 加载共享内存的过程。在每一步中,每个线程负责加载一个元素到共享内存中。这个元素的索引是 `inner_row_A` 和 `inner_col_A` 。for 循环中的 `load_offset` 递增的步长是 `stride_A` 。在图中就是向下移动了 `stride_A` 个元素。 @@ -130,7 +136,7 @@ for (uint dot_idx = 0; dot_idx < BK; ++dot_idx) { for (uint reg_idx_n = 0; reg_idx_n < TN; ++reg_idx_n) { - thread_results[reg_idx_m * TN + reg_idx_n] += + thread_results[reg_idx_m * TN + reg_idx_n] += reg_m[reg_idx_m] * reg_n[reg_idx_n]; } } @@ -158,7 +164,7 @@ for (uint reg_idx_m = 0; reg_idx_m < TM; ++reg_idx_m) { for (uint reg_idx_n = 0; reg_idx_n < TN; ++reg_idx_n) { - C[(thread_row * TM + reg_idx_m) * N + thread_col * TN + reg_idx_n] = + C[(thread_row * TM + reg_idx_m) * N + thread_col * TN + reg_idx_n] = thread_results[reg_idx_m * TN + reg_idx_n]; } } @@ -174,20 +180,20 @@ nvcc -o sgemm_tiled2d sgemm_tiled2d.cu ## 3. 性能测试 我们将上该内核的性能和之前的内核进行比较,我们分别计算 256x256、512x512、1024x1024、2048x2048 (Matrix 1、Matrix 2、Matrix 3、Matrix 4、Matrix 5)的矩阵乘法的性能 (us)。在 1080Ti 上运行,结果如下: - -| Algorithm | Matrix 1 | Matrix 2 | Matrix 3 | Matrix 4 | -| --------- | -------- | -------- | -------- | -------- | -| Naive | 95.5152 | 724.396 | 28424 | 228681 | -| 共享内存缓存块 | 40.5293 | 198.374 | 8245.68 | 59048.4 | -| 一维 Thread Tile | 35.215 | 174.731 | 894.779 | 5880.03 | -| 二维 Thread Tile | 34.708 | 92.946 | 557.829 | 3509.920 | + +| Algorithm | Matrix 1 | Matrix 2 | Matrix 3 | Matrix 4 | +| ---------------- | -------- | -------- | -------- | -------- | +| Naive | 95.5152 | 724.396 | 28424 | 228681 | +| 共享内存缓存块 | 40.5293 | 198.374 | 8245.68 | 59048.4 | +| 一维 Thread Tile | 35.215 | 174.731 | 894.779 | 5880.03 | +| 二维 Thread Tile | 34.708 | 92.946 | 557.829 | 3509.920 | ## 4. 总结 本文我们介绍了二维 Thread Tile 并行优化方法。我们将矩阵乘法的计算任务分配给了二维线程块,每个线程块负责计算一个小的矩阵块。这样做的好处是可以充分利用共享内存,减少全局内存的访问次数,从而提高矩阵乘法的性能。 -## Reference +## Reference 1. https://siboehm.com/articles/22/CUDA-MMM 2. https://space.keter.top/docs/high_performance/GEMM%E4%BC%98%E5%8C%96%E4%B8%93%E9%A2%98/%E4%BA%8C%E7%BB%B4Thread%20Tile%E5%B9%B6%E8%A1%8C%E4%BC%98%E5%8C%96 diff --git a/docs/11_gemm_optimize/01_tiled2d/sgemm_tiled2d.cu b/docs/11_gemm_optimize/01_tiled2d/sgemm_tiled2d.cu index 1a6361d..615c84a 100644 --- a/docs/11_gemm_optimize/01_tiled2d/sgemm_tiled2d.cu +++ b/docs/11_gemm_optimize/01_tiled2d/sgemm_tiled2d.cu @@ -2,7 +2,22 @@ #include #include -#define CEIL_DIV(M, N) (((M) + (N)-1) / (N)) +#define CEIL_DIV(M, N) (((M) + (N) - 1) / (N)) + +template +void free_resource(T *p_cpu, T *p_cuda) +{ + if (nullptr != p_cpu) + { + delete[] p_cpu; + p_cpu = nullptr; + } + if (nullptr != p_cuda) + { + cudaFree(p_cuda); + p_cuda = nullptr; + } +} void sgemm_naive_cpu(float *A, float *B, float *C, int M, int N, int K) { @@ -155,7 +170,7 @@ int main(int argc, char *argv[]) // Allocate memory for matrices float *A, *B, *C, *C_ref; - float *d_A, *d_B, *d_C, *d_C_ref; + float *d_A, *d_B, *d_C; A = new float[m * k]; B = new float[k * n]; @@ -176,13 +191,11 @@ int main(int argc, char *argv[]) cudaMalloc((void **)&d_A, m * k * sizeof(float)); cudaMalloc((void **)&d_B, k * n * sizeof(float)); cudaMalloc((void **)&d_C, m * n * sizeof(float)); - cudaMalloc((void **)&d_C_ref, m * n * sizeof(float)); // Copy matrices to device cudaMemcpy(d_A, A, m * k * sizeof(float), cudaMemcpyHostToDevice); cudaMemcpy(d_B, B, k * n * sizeof(float), cudaMemcpyHostToDevice); cudaMemcpy(d_C, C, m * n * sizeof(float), cudaMemcpyHostToDevice); - cudaMemcpy(d_C_ref, C_ref, m * n * sizeof(float), cudaMemcpyHostToDevice); run_sgemm_blocktiling_2d(d_A, d_B, d_C, m, n, k); @@ -219,5 +232,11 @@ int main(int argc, char *argv[]) cudaEventElapsedTime(&elapsed_time, start, stop); float avg_run_time = elapsed_time * 1000 / 100; printf("Average run time: %f us\n", avg_run_time); + + free_resource(A, d_A); + free_resource(B, d_B); + free_resource(C, d_C); + free_resource(C_ref, (float *)nullptr); + return 0; } \ No newline at end of file