Skip to content

Commit 4be041f

Browse files
authored
[FlashAttention] replace FLOAT4 with LDST128BITS macro (xlite-dev#41)
1 parent 068e6fe commit 4be041f

File tree

1 file changed

+13
-12
lines changed

1 file changed

+13
-12
lines changed

flash-attn/flash_attn_2_fwd_f16_mma_m16n8k16.cu

+13 -12
Original file line number | Diff line number | Diff line change
@@ -14,6 +14,7 @@
1414
#define FLOAT4(value) (reinterpret_cast<float4*>(&(value))[0])
1515
#define HALF2(value) (reinterpret_cast<half2*>(&(value))[0])
1616
#define BFLOAT2(value) (reinterpret_cast<__nv_bfloat162*>(&(value))[0])
17+
#define LDST128BITS(value) (reinterpret_cast<float4*>(&(value))[0])
1718

1819
// Load matrix to REGISTER
1920
#define LDMATRIX_X4(R0, R1, R2, R3, addr) \
@@ -62,7 +63,7 @@ __global__ void flash_attn_2_fwd_f16_mma_m16n8k16_kernel(
6263
int new_dim_x = dim_x % 16;
6364
int new_dim_y = (dim_y / 16 * (d / 16) * 16) + (dim_x / 16 * 16) + (dim_y % 16);
6465

65-
FLOAT4(Qi[new_dim_y * 16 + new_dim_x]) = FLOAT4(Q[qkv_offset + (i * tile_size) + x]);
66+
LDST128BITS(Qi[new_dim_y * 16 + new_dim_x]) = LDST128BITS(Q[qkv_offset + (i * tile_size) + x]);
6667
}
6768
__syncthreads();
6869

@@ -92,7 +93,7 @@ __global__ void flash_attn_2_fwd_f16_mma_m16n8k16_kernel(
9293
int new_dim_x = dim_x % 16;
9394
int new_dim_y = (dim_y / 16 * (d / 16) * 16) + (dim_x / 16 * 16) + (dim_y % 16);
9495

95-
FLOAT4(Kj[new_dim_y * 16 + new_dim_x]) = FLOAT4(K[qkv_offset + (j * tile_size) + x]);
96+
LDST128BITS(Kj[new_dim_y * 16 + new_dim_x]) = LDST128BITS(K[qkv_offset + (j * tile_size) + x]);
9697
}
9798
__syncthreads();
9899

@@ -124,7 +125,7 @@ __global__ void flash_attn_2_fwd_f16_mma_m16n8k16_kernel(
124125

125126
// Read V from global memory to shared memory
126127
for (int x = threadIdx.x * 8; x < tile_size; x += 1024) {
127-
FLOAT4(reg[0]) = FLOAT4(V[qkv_offset + (j * tile_size) + x]);
128+
LDST128BITS(reg[0]) = LDST128BITS(V[qkv_offset + (j * tile_size) + x]);
128129

129130
int dim_x = x % d;
130131
int dim_y = x / d;
@@ -142,10 +143,10 @@ __global__ void flash_attn_2_fwd_f16_mma_m16n8k16_kernel(
142143
// adapt from https://github.com/jundaf2/INT8-Flash-Attention-FMHA-Quantization/blob/main/inc/fmha_i8.cuh
143144
// Softmax phase (m, l calculate)
144145
// FETCHING REGISTER
145-
FLOAT4(reg[0]) = FLOAT4(RC[0][0]);
146-
FLOAT4(reg[8]) = FLOAT4(RC[2][0]);
147-
FLOAT4(reg[16]) = FLOAT4(RC[4][0]);
148-
FLOAT4(reg[24]) = FLOAT4(RC[6][0]);
146+
LDST128BITS(reg[0]) = LDST128BITS(RC[0][0]);
147+
LDST128BITS(reg[8]) = LDST128BITS(RC[2][0]);
148+
LDST128BITS(reg[16]) = LDST128BITS(RC[4][0]);
149+
LDST128BITS(reg[24]) = LDST128BITS(RC[6][0]);
149150

150151
// thread level reduce max
151152
#pragma unroll
@@ -197,10 +198,10 @@ __global__ void flash_attn_2_fwd_f16_mma_m16n8k16_kernel(
197198
}
198199

199200
// FETCHING REGISTER for P
200-
FLOAT4(RC[0][0]) = FLOAT4(reg[0]);
201-
FLOAT4(RC[2][0]) = FLOAT4(reg[8]);
202-
FLOAT4(RC[4][0]) = FLOAT4(reg[16]);
203-
FLOAT4(RC[6][0]) = FLOAT4(reg[24]);
201+
LDST128BITS(RC[0][0]) = LDST128BITS(reg[0]);
202+
LDST128BITS(RC[2][0]) = LDST128BITS(reg[8]);
203+
LDST128BITS(RC[4][0]) = LDST128BITS(reg[16]);
204+
LDST128BITS(RC[6][0]) = LDST128BITS(reg[24]);
204205

205206
// P @ V
206207
for (int k = 0; k < d / 16; k++) {
@@ -220,7 +221,7 @@ __global__ void flash_attn_2_fwd_f16_mma_m16n8k16_kernel(
220221
RD[2], RD[3]);
221222
}
222223

223-
FLOAT4(reg[0]) = FLOAT4(RD[0]);
224+
LDST128BITS(reg[0]) = LDST128BITS(RD[0]);
224225
#pragma unroll
225226
for(int tc_yi = 0; tc_yi < 2; tc_yi++) {
226227
float thread_max_new = max(thread_max_old[tc_yi], thread_max[tc_yi]);

0 commit comments

Comments
 (0)