7
7
#include < cuda.h>
8
8
#include < cstdlib>
9
9
10
- #define CEIL_DIV (M, N ) (((M) + (N)- 1 ) / (N))
10
+ #define CEIL_DIV (M, N ) (((M) + (N) - 1 ) / (N))
11
11
#define OFFSET (row, col, ld ) ((row) * (ld) + (col))
12
12
#define FETCH_FLOAT4 (pointer ) (reinterpret_cast <float4 *>(&(pointer))[0 ])
13
13
14
// Release a buffer allocated either on the device (cudaMalloc) or on the
// host (new[]).
//
// ptr:     buffer to free; null is tolerated (both cudaFree and delete[]
//          are defined no-ops on nullptr, the guard just makes that explicit).
// is_cuda: non-zero -> device memory freed with cudaFree;
//          zero     -> host memory freed with delete[].
//
// Note: the pointer is passed by value, so this function cannot null out the
// caller's pointer. The original ended with `ptr = nullptr;`, which only
// cleared the local copy — a misleading no-op, removed here. Callers are
// responsible for not reusing the pointer after this call.
void free_resource(float *ptr, int is_cuda = 1)
{
    if (nullptr != ptr)
    {
        if (is_cuda)
        {
            cudaFree(ptr);
        }
        else
        {
            delete[] ptr;
        }
    }
}
29
+
14
30
void sgemm_naive_cpu (float *A, float *B, float *C, int M, int N, int K)
15
31
{
16
32
for (int x = 0 ; x < M; x++)
@@ -36,8 +52,8 @@ __global__ void __launch_bounds__((BM * BN) / (TM * TN), 1) sgemm_vectorize_kern
36
52
const uint c_row = blockIdx .y ;
37
53
const uint c_col = blockIdx .x ;
38
54
39
- const int block_row_thread = BN / TN ;
40
- const int block_col_thread = BM / TM ;
55
+ const int block_row_thread = BM / TM ;
56
+ const int block_col_thread = BN / TN ;
41
57
// 一个线程负责计算 block 中 TM*TN 个元素
42
58
const int thread_num = block_row_thread * block_col_thread;
43
59
@@ -73,8 +89,8 @@ __global__ void __launch_bounds__((BM * BN) / (TM * TN), 1) sgemm_vectorize_kern
73
89
C += c_row * BM * N + c_col * BN;
74
90
75
91
float thread_results[TM * TN] = {0.0 };
76
- // 每个线程搬运ldg_a_num轮,寄存器缓存ldg_a_num个float4元素,用于转置As矩阵
77
- float ldg_reg_a[4 * ldg_a_num ] = {0 .};
92
+ // 转置时,只用大小为 4 的数组就可以
93
+ float ldg_reg_a[4 ] = {0 .};
78
94
float reg_a[TM] = {0.0 }; // 缓存 smem_a
79
95
float reg_b[TN] = {0.0 }; // 缓存 smem_b
80
96
@@ -83,13 +99,12 @@ __global__ void __launch_bounds__((BM * BN) / (TM * TN), 1) sgemm_vectorize_kern
83
99
{
84
100
for (int i = 0 ; i < BM; i += stride_a)
85
101
{
86
- int ldg_index = i / stride_a * 4 ;
87
- FETCH_FLOAT4 (ldg_reg_a[ldg_index]) = FETCH_FLOAT4 (A[OFFSET (i + inner_row_a, inner_col_a, K)]);
102
+ FETCH_FLOAT4 (ldg_reg_a[0 ]) = FETCH_FLOAT4 (A[OFFSET (i + inner_row_a, inner_col_a, K)]);
88
103
// smem_a 转置存,其中 ldg_reg_a 做中间缓存,目的是读取时可以按FLOAT4读取
89
- smem_a[OFFSET (inner_col_a, i + inner_row_a, BM)] = ldg_reg_a[ldg_index ];
90
- smem_a[OFFSET (inner_col_a + 1 , i + inner_row_a, BM)] = ldg_reg_a[ldg_index + 1 ];
91
- smem_a[OFFSET (inner_col_a + 2 , i + inner_row_a, BM)] = ldg_reg_a[ldg_index + 2 ];
92
- smem_a[OFFSET (inner_col_a + 3 , i + inner_row_a, BM)] = ldg_reg_a[ldg_index + 3 ];
104
+ smem_a[OFFSET (inner_col_a, i + inner_row_a, BM)] = ldg_reg_a[0 ];
105
+ smem_a[OFFSET (inner_col_a + 1 , i + inner_row_a, BM)] = ldg_reg_a[1 ];
106
+ smem_a[OFFSET (inner_col_a + 2 , i + inner_row_a, BM)] = ldg_reg_a[2 ];
107
+ smem_a[OFFSET (inner_col_a + 3 , i + inner_row_a, BM)] = ldg_reg_a[3 ];
93
108
}
94
109
95
110
for (int i = 0 ; i < BK; i += stride_b)
@@ -166,7 +181,7 @@ int main(int argc, char *argv[])
166
181
167
182
// Allocate memory for matrices
168
183
float *A, *B, *C, *C_ref;
169
- float *d_A, *d_B, *d_C, *d_C_ref ;
184
+ float *d_A, *d_B, *d_C;
170
185
171
186
A = new float [m * k];
172
187
B = new float [k * n];
@@ -183,17 +198,10 @@ int main(int argc, char *argv[])
183
198
cudaMalloc ((void **)&d_B, k * n * sizeof (float ));
184
199
cudaMalloc ((void **)&d_C, m * n * sizeof (float ));
185
200
186
- // Copy data to device
187
- cudaMalloc ((void **)&d_A, m * k * sizeof (float ));
188
- cudaMalloc ((void **)&d_B, k * n * sizeof (float ));
189
- cudaMalloc ((void **)&d_C, m * n * sizeof (float ));
190
- cudaMalloc ((void **)&d_C_ref, m * n * sizeof (float ));
191
-
192
201
// Copy matrices to device
193
202
cudaMemcpy (d_A, A, m * k * sizeof (float ), cudaMemcpyHostToDevice);
194
203
cudaMemcpy (d_B, B, k * n * sizeof (float ), cudaMemcpyHostToDevice);
195
204
cudaMemcpy (d_C, C, m * n * sizeof (float ), cudaMemcpyHostToDevice);
196
- cudaMemcpy (d_C_ref, C_ref, m * n * sizeof (float ), cudaMemcpyHostToDevice);
197
205
198
206
run_sgemm_vectorize (d_A, d_B, d_C, m, n, k);
199
207
@@ -230,5 +238,15 @@ int main(int argc, char *argv[])
230
238
cudaEventElapsedTime (&elapsed_time, start, stop);
231
239
float avg_run_time = elapsed_time * 1000 / 100 ;
232
240
printf (" Average run time: %f us\n " , avg_run_time);
241
+
242
+ free_resource (A, 0 );
243
+ free_resource (B, 0 );
244
+ free_resource (C, 0 );
245
+ free_resource (C_ref, 0 );
246
+
247
+ free_resource (d_A, 1 );
248
+ free_resource (d_B, 1 );
249
+ free_resource (d_C, 1 );
250
+
233
251
return 0 ;
234
252
}
0 commit comments