14
14
// Reinterpret the storage that starts at `value` as a single float4 and yield
// it as an lvalue, so the macro can appear on either side of an assignment
// (e.g. FLOAT4(dst[i]) = FLOAT4(src[j]) copies one float4-sized chunk).
// The macro name must be followed immediately by `(`: with a space in between
// the preprocessor defines an object-like macro and every FLOAT4(x) use breaks.
// The caller must ensure &(value) is suitably aligned for a float4 access.
#define FLOAT4(value) (reinterpret_cast<float4 *>(&(value))[0])
15
15
// Reinterpret the storage that starts at `value` as a single half2 and yield
// it as an lvalue, usable on either side of an assignment.
// The macro name must be followed immediately by `(`: with a space in between
// the preprocessor defines an object-like macro and every HALF2(x) use breaks.
// The caller must ensure &(value) is suitably aligned for a half2 access.
#define HALF2(value) (reinterpret_cast<half2 *>(&(value))[0])
16
16
// Reinterpret the storage that starts at `value` as a single __nv_bfloat162
// and yield it as an lvalue, usable on either side of an assignment.
// The macro name must be followed immediately by `(`: with a space in between
// the preprocessor defines an object-like macro and every BFLOAT2(x) use breaks.
// The caller must ensure &(value) is suitably aligned for the access.
#define BFLOAT2(value) (reinterpret_cast<__nv_bfloat162 *>(&(value))[0])
17
+ // Reinterpret the storage that starts at `value` as one float4 lvalue, so a
+ // single assignment LDST128BITS(dst) = LDST128BITS(src) moves a float4-sized
+ // chunk. The name must be followed immediately by `(` to stay function-like.
+ // The caller must ensure &(value) is suitably aligned for a float4 access.
+ #define LDST128BITS(value) (reinterpret_cast<float4 *>(&(value))[0])
17
18
18
19
// Load matrix to REGISTER
19
20
#define LDMATRIX_X4 (R0, R1, R2, R3, addr ) \
@@ -62,7 +63,7 @@ __global__ void flash_attn_2_fwd_f16_mma_m16n8k16_kernel(
62
63
int new_dim_x = dim_x % 16 ;
63
64
int new_dim_y = (dim_y / 16 * (d / 16 ) * 16 ) + (dim_x / 16 * 16 ) + (dim_y % 16 );
64
65
65
- FLOAT4 (Qi[new_dim_y * 16 + new_dim_x]) = FLOAT4 (Q[qkv_offset + (i * tile_size) + x]);
66
+ LDST128BITS (Qi[new_dim_y * 16 + new_dim_x]) = LDST128BITS (Q[qkv_offset + (i * tile_size) + x]);
66
67
}
67
68
__syncthreads ();
68
69
@@ -92,7 +93,7 @@ __global__ void flash_attn_2_fwd_f16_mma_m16n8k16_kernel(
92
93
int new_dim_x = dim_x % 16 ;
93
94
int new_dim_y = (dim_y / 16 * (d / 16 ) * 16 ) + (dim_x / 16 * 16 ) + (dim_y % 16 );
94
95
95
- FLOAT4 (Kj[new_dim_y * 16 + new_dim_x]) = FLOAT4 (K[qkv_offset + (j * tile_size) + x]);
96
+ LDST128BITS (Kj[new_dim_y * 16 + new_dim_x]) = LDST128BITS (K[qkv_offset + (j * tile_size) + x]);
96
97
}
97
98
__syncthreads ();
98
99
@@ -124,7 +125,7 @@ __global__ void flash_attn_2_fwd_f16_mma_m16n8k16_kernel(
124
125
125
126
// Read V from global memory to shared memory
126
127
for (int x = threadIdx .x * 8 ; x < tile_size; x += 1024 ) {
127
- FLOAT4 (reg[0 ]) = FLOAT4 (V[qkv_offset + (j * tile_size) + x]);
128
+ LDST128BITS (reg[0 ]) = LDST128BITS (V[qkv_offset + (j * tile_size) + x]);
128
129
129
130
int dim_x = x % d;
130
131
int dim_y = x / d;
@@ -142,10 +143,10 @@ __global__ void flash_attn_2_fwd_f16_mma_m16n8k16_kernel(
142
143
// adapted from https://github.com/jundaf2/INT8-Flash-Attention-FMHA-Quantization/blob/main/inc/fmha_i8.cuh
143
144
// Softmax phase (m, l calculate)
144
145
// FETCHING REGISTER
145
- FLOAT4 (reg[0 ]) = FLOAT4 (RC[0 ][0 ]);
146
- FLOAT4 (reg[8 ]) = FLOAT4 (RC[2 ][0 ]);
147
- FLOAT4 (reg[16 ]) = FLOAT4 (RC[4 ][0 ]);
148
- FLOAT4 (reg[24 ]) = FLOAT4 (RC[6 ][0 ]);
146
+ LDST128BITS (reg[0 ]) = LDST128BITS (RC[0 ][0 ]);
147
+ LDST128BITS (reg[8 ]) = LDST128BITS (RC[2 ][0 ]);
148
+ LDST128BITS (reg[16 ]) = LDST128BITS (RC[4 ][0 ]);
149
+ LDST128BITS (reg[24 ]) = LDST128BITS (RC[6 ][0 ]);
149
150
150
151
// thread level reduce max
151
152
#pragma unroll
@@ -197,10 +198,10 @@ __global__ void flash_attn_2_fwd_f16_mma_m16n8k16_kernel(
197
198
}
198
199
199
200
// FETCHING REGISTER for P
200
- FLOAT4 (RC[0 ][0 ]) = FLOAT4 (reg[0 ]);
201
- FLOAT4 (RC[2 ][0 ]) = FLOAT4 (reg[8 ]);
202
- FLOAT4 (RC[4 ][0 ]) = FLOAT4 (reg[16 ]);
203
- FLOAT4 (RC[6 ][0 ]) = FLOAT4 (reg[24 ]);
201
+ LDST128BITS (RC[0 ][0 ]) = LDST128BITS (reg[0 ]);
202
+ LDST128BITS (RC[2 ][0 ]) = LDST128BITS (reg[8 ]);
203
+ LDST128BITS (RC[4 ][0 ]) = LDST128BITS (reg[16 ]);
204
+ LDST128BITS (RC[6 ][0 ]) = LDST128BITS (reg[24 ]);
204
205
205
206
// P @ V
206
207
for (int k = 0 ; k < d / 16 ; k++) {
@@ -220,7 +221,7 @@ __global__ void flash_attn_2_fwd_f16_mma_m16n8k16_kernel(
220
221
RD[2 ], RD[3 ]);
221
222
}
222
223
223
- FLOAT4 (reg[0 ]) = FLOAT4 (RD[0 ]);
224
+ LDST128BITS (reg[0 ]) = LDST128BITS (RD[0 ]);
224
225
#pragma unroll
225
226
for (int tc_yi = 0 ; tc_yi < 2 ; tc_yi++) {
226
227
float thread_max_new = max (thread_max_old[tc_yi], thread_max[tc_yi]);
0 commit comments