Skip to content

Commit

Permalink
[swizzle] add padding -> swizzle layout tools🎉 (#198)
Browse files Browse the repository at this point in the history
* Update README.md

* add pad -> swizzle layout tools
  • Loading branch information
DefTruth authored Dec 28, 2024
1 parent fd993a9 commit 6c811c9
Show file tree
Hide file tree
Showing 4 changed files with 123 additions and 31 deletions.
46 changes: 37 additions & 9 deletions kernels/flash-attn/tools/print_swizzle_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,11 @@ def swizzle_permuted_j(i: int,


def print_smem_swizzle_layout(rows: int = 16,
logical_col_stride: int = 16,
num_elems_per_128b: int = 8,
show_logical_col_id: bool = False,
use_logical_col_stride: bool = False):
logical_col_stride: int = 16,
num_elems_per_128b: int = 8,
smem_pading: int = 0,
show_logical_col_id: bool = False,
use_logical_col_stride: bool = False):
# ----------------------------------------------------------------
# [INFO] Assert smem store layout col_stride <= 16, prefer 16. |
# [INFO] For logical_col_stride > 16, we have to permute the |
Expand Down Expand Up @@ -95,12 +96,17 @@ def print_smem_swizzle_layout(rows: int = 16,
# ----------------------------------------------------------------
str_len = 0
total_banks = 0
assert smem_pading == 0 or smem_pading == 8, "smem_pading must be 0 or 8"
# 4 bytes per bank
banks_per_col = int((16 * 2) / 4) if logical_col_stride >= 16 else 4
if use_logical_col_stride:
banks_per_col = int((logical_col_stride * 2) / 4)
if logical_col_stride > 16:
print(f"[WARN] col_stride must <= 16, but got {logical_col_stride}")
if smem_pading == 8:
banks_per_col += 4
print(f"[INFO] smem padding 8 half values, 4 banks, banks_per_col: {banks_per_col}")

banks_per_num_elems_per_128b = int((num_elems_per_128b * 2) / 4)
for i in range(rows):
layout_str_len = 0
Expand Down Expand Up @@ -139,13 +145,33 @@ def print_smem_swizzle_layout(rows: int = 16,
num_elems_per_128b)
logical_col_ids.append(j)
smem_layout_col_ids.append(layout_j)

smem_layout_str = f"|row {i:<2}|"

r = 0
for c, l in zip(logical_col_ids, smem_layout_col_ids):
smem_layout_str += pretty_print_line((f"{c:>2}:{l:<2}" if
show_logical_col_id else f"{l:<2}"),
sep=" ",
width=max_bank_str_len-1,
return_str=True) + "|"
smem_layout_str += pretty_print_line(
(f"{c:>2}:{l:<2}" if show_logical_col_id else f"{l:<2}"),
sep=" ",
width=(max_bank_str_len-1),
return_str=True
) + "|"
r += 1
if logical_col_stride >= 16:
if smem_pading == 8 and (r > 1 and r % 2 == 0):
smem_layout_str += pretty_print_line(
(f"pad"),
sep=" ", width=max_bank_str_len-1,
return_str=True
) + "|"
else:
if smem_pading == 8:
smem_layout_str += pretty_print_line(
(f"pad"),
sep=" ", width=max_bank_str_len-1,
return_str=True
) + "|"

layout_str_len = len(smem_layout_str)
str_len = max(layout_str_len, banks_str_len)

Expand All @@ -172,6 +198,7 @@ def print_smem_swizzle_layout(rows: int = 16,
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--rows", type=int, default=16)
parser.add_argument("--smem-padding", "--pad", type=int, default=0)
parser.add_argument("--num-elems-per-128b", "--num-elems", type=int, default=8)
parser.add_argument("--logical-col-stride", "--logical-col", "--col", type=int, default=64)
parser.add_argument("--use-logical-col-stride", "--use-logical-col", action="store_true")
Expand All @@ -186,6 +213,7 @@ def get_args():
print_smem_swizzle_layout(rows=args.rows,
logical_col_stride=args.logical_col_stride,
num_elems_per_128b=args.num_elems_per_128b,
smem_pading=args.smem_padding,
show_logical_col_id=args.show_logical_col_id,
use_logical_col_stride=args.use_logical_col_stride)

46 changes: 37 additions & 9 deletions kernels/hgemm/tools/print_swizzle_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,11 @@ def swizzle_permuted_j(i: int,


def print_smem_swizzle_layout(rows: int = 16,
logical_col_stride: int = 16,
num_elems_per_128b: int = 8,
show_logical_col_id: bool = False,
use_logical_col_stride: bool = False):
logical_col_stride: int = 16,
num_elems_per_128b: int = 8,
smem_pading: int = 0,
show_logical_col_id: bool = False,
use_logical_col_stride: bool = False):
# ----------------------------------------------------------------
# [INFO] Assert smem store layout col_stride <= 16, prefer 16. |
# [INFO] For logical_col_stride > 16, we have to permute the |
Expand Down Expand Up @@ -95,12 +96,17 @@ def print_smem_swizzle_layout(rows: int = 16,
# ----------------------------------------------------------------
str_len = 0
total_banks = 0
assert smem_pading == 0 or smem_pading == 8, "smem_pading must be 0 or 8"
# 4 bytes per bank
banks_per_col = int((16 * 2) / 4) if logical_col_stride >= 16 else 4
if use_logical_col_stride:
banks_per_col = int((logical_col_stride * 2) / 4)
if logical_col_stride > 16:
print(f"[WARN] col_stride must <= 16, but got {logical_col_stride}")
if smem_pading == 8:
banks_per_col += 4
print(f"[INFO] smem padding 8 half values, 4 banks, banks_per_col: {banks_per_col}")

banks_per_num_elems_per_128b = int((num_elems_per_128b * 2) / 4)
for i in range(rows):
layout_str_len = 0
Expand Down Expand Up @@ -139,13 +145,33 @@ def print_smem_swizzle_layout(rows: int = 16,
num_elems_per_128b)
logical_col_ids.append(j)
smem_layout_col_ids.append(layout_j)

smem_layout_str = f"|row {i:<2}|"

r = 0
for c, l in zip(logical_col_ids, smem_layout_col_ids):
smem_layout_str += pretty_print_line((f"{c:>2}:{l:<2}" if
show_logical_col_id else f"{l:<2}"),
sep=" ",
width=max_bank_str_len-1,
return_str=True) + "|"
smem_layout_str += pretty_print_line(
(f"{c:>2}:{l:<2}" if show_logical_col_id else f"{l:<2}"),
sep=" ",
width=(max_bank_str_len-1),
return_str=True
) + "|"
r += 1
if logical_col_stride >= 16:
if smem_pading == 8 and (r > 1 and r % 2 == 0):
smem_layout_str += pretty_print_line(
(f"pad"),
sep=" ", width=max_bank_str_len-1,
return_str=True
) + "|"
else:
if smem_pading == 8:
smem_layout_str += pretty_print_line(
(f"pad"),
sep=" ", width=max_bank_str_len-1,
return_str=True
) + "|"

layout_str_len = len(smem_layout_str)
str_len = max(layout_str_len, banks_str_len)

Expand All @@ -172,6 +198,7 @@ def print_smem_swizzle_layout(rows: int = 16,
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--rows", type=int, default=16)
parser.add_argument("--smem-padding", "--pad", type=int, default=0)
parser.add_argument("--num-elems-per-128b", "--num-elems", type=int, default=8)
parser.add_argument("--logical-col-stride", "--logical-col", "--col", type=int, default=64)
parser.add_argument("--use-logical-col-stride", "--use-logical-col", action="store_true")
Expand All @@ -186,6 +213,7 @@ def get_args():
print_smem_swizzle_layout(rows=args.rows,
logical_col_stride=args.logical_col_stride,
num_elems_per_128b=args.num_elems_per_128b,
smem_pading=args.smem_padding,
show_logical_col_id=args.show_logical_col_id,
use_logical_col_stride=args.use_logical_col_stride)

16 changes: 12 additions & 4 deletions kernels/swizzle/hgemm_mma_swizzle.cu
Original file line number Diff line number Diff line change
Expand Up @@ -520,7 +520,7 @@ void launch_hgemm_mma_m16n8k16_mma2x4_warp4x4(
constexpr int MMA_TILE_N = 4;
constexpr int WARP_TILE_M = 4;
constexpr int WARP_TILE_N = 4;
// bank conflicts free via pad = 8; don't rely on guesswork, trust the profiler.
// bank conflicts free via pad = 8.
// ncu --metrics l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld ./hgemm_mma_swizzle.bin
// ncu --metrics sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm ./hgemm_mma_swizzle.bin
// constexpr int A_PAD = 8;
Expand All @@ -541,6 +541,7 @@ void launch_hgemm_mma_m16n8k16_mma2x4_warp4x4(
);
}

template <const int B_PAD = 8>
void launch_hgemm_mma_m16n8k16_mma2x4_warp4x4_smem_swizzle(
half* a, half* b, half* c, int M, int N, int K) {
constexpr int MMA_M = 16;
Expand All @@ -551,7 +552,7 @@ void launch_hgemm_mma_m16n8k16_mma2x4_warp4x4_smem_swizzle(
constexpr int WARP_TILE_M = 4;
constexpr int WARP_TILE_N = 4;
constexpr int A_PAD = 0;
constexpr int B_PAD = 8;
// B_PAD is a template parameter (default 8); pad = 8 makes B bank-conflict free.
constexpr int NUM_THREADS= (
MMA_TILE_M * MMA_TILE_N * WARP_SIZE); // 2 * 4 * 32 = 256
dim3 block(NUM_THREADS);
Expand Down Expand Up @@ -644,9 +645,16 @@ int main(int argc, char *argv[]) {
avg_Tflops = ((double)M) * N * K * 2 * 1e-12 / avg_sec;
printf("M N K = %6d %6d %6d, W = %d, R = %d, ", M, N, K, W, R);
printf("Time = %12.8lf s, AVG Performance = %10.4lf Tflops\n", avg_sec, avg_Tflops);

printf("\nALGO = HGEMM mma2x4_warp4x4 + A SMEM SWIZZLE + B_PAD 0\n");
avg_sec = perf_gemm<half>(launch_hgemm_mma_m16n8k16_mma2x4_warp4x4_smem_swizzle<0>,
M, N, K, W, R);
avg_Tflops = ((double)M) * N * K * 2 * 1e-12 / avg_sec;
printf("M N K = %6d %6d %6d, W = %d, R = %d, ", M, N, K, W, R);
printf("Time = %12.8lf s, AVG Performance = %10.4lf Tflops\n", avg_sec, avg_Tflops);

printf("\nALGO = HGEMM mma2x4_warp4x4 + SMEM SWIZZLE\n");
avg_sec = perf_gemm<half>(launch_hgemm_mma_m16n8k16_mma2x4_warp4x4_smem_swizzle,
printf("\nALGO = HGEMM mma2x4_warp4x4 + A SMEM SWIZZLE + B_PAD 8\n");
avg_sec = perf_gemm<half>(launch_hgemm_mma_m16n8k16_mma2x4_warp4x4_smem_swizzle<8>,
M, N, K, W, R);
avg_Tflops = ((double)M) * N * K * 2 * 1e-12 / avg_sec;
printf("M N K = %6d %6d %6d, W = %d, R = %d, ", M, N, K, W, R);
Expand Down
46 changes: 37 additions & 9 deletions kernels/swizzle/print_swizzle_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,11 @@ def swizzle_permuted_j(i: int,


def print_smem_swizzle_layout(rows: int = 16,
logical_col_stride: int = 16,
num_elems_per_128b: int = 8,
show_logical_col_id: bool = False,
use_logical_col_stride: bool = False):
logical_col_stride: int = 16,
num_elems_per_128b: int = 8,
smem_pading: int = 0,
show_logical_col_id: bool = False,
use_logical_col_stride: bool = False):
# ----------------------------------------------------------------
# [INFO] Assert smem store layout col_stride <= 16, prefer 16. |
# [INFO] For logical_col_stride > 16, we have to permute the |
Expand Down Expand Up @@ -95,12 +96,17 @@ def print_smem_swizzle_layout(rows: int = 16,
# ----------------------------------------------------------------
str_len = 0
total_banks = 0
assert smem_pading == 0 or smem_pading == 8, "smem_pading must be 0 or 8"
# 4 bytes per bank
banks_per_col = int((16 * 2) / 4) if logical_col_stride >= 16 else 4
if use_logical_col_stride:
banks_per_col = int((logical_col_stride * 2) / 4)
if logical_col_stride > 16:
print(f"[WARN] col_stride must <= 16, but got {logical_col_stride}")
if smem_pading == 8:
banks_per_col += 4
print(f"[INFO] smem padding 8 half values, 4 banks, banks_per_col: {banks_per_col}")

banks_per_num_elems_per_128b = int((num_elems_per_128b * 2) / 4)
for i in range(rows):
layout_str_len = 0
Expand Down Expand Up @@ -139,13 +145,33 @@ def print_smem_swizzle_layout(rows: int = 16,
num_elems_per_128b)
logical_col_ids.append(j)
smem_layout_col_ids.append(layout_j)

smem_layout_str = f"|row {i:<2}|"

r = 0
for c, l in zip(logical_col_ids, smem_layout_col_ids):
smem_layout_str += pretty_print_line((f"{c:>2}:{l:<2}" if
show_logical_col_id else f"{l:<2}"),
sep=" ",
width=max_bank_str_len-1,
return_str=True) + "|"
smem_layout_str += pretty_print_line(
(f"{c:>2}:{l:<2}" if show_logical_col_id else f"{l:<2}"),
sep=" ",
width=(max_bank_str_len-1),
return_str=True
) + "|"
r += 1
if logical_col_stride >= 16:
if smem_pading == 8 and (r > 1 and r % 2 == 0):
smem_layout_str += pretty_print_line(
(f"pad"),
sep=" ", width=max_bank_str_len-1,
return_str=True
) + "|"
else:
if smem_pading == 8:
smem_layout_str += pretty_print_line(
(f"pad"),
sep=" ", width=max_bank_str_len-1,
return_str=True
) + "|"

layout_str_len = len(smem_layout_str)
str_len = max(layout_str_len, banks_str_len)

Expand All @@ -172,6 +198,7 @@ def print_smem_swizzle_layout(rows: int = 16,
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--rows", type=int, default=16)
parser.add_argument("--smem-padding", "--pad", type=int, default=0)
parser.add_argument("--num-elems-per-128b", "--num-elems", type=int, default=8)
parser.add_argument("--logical-col-stride", "--logical-col", "--col", type=int, default=64)
parser.add_argument("--use-logical-col-stride", "--use-logical-col", action="store_true")
Expand All @@ -186,6 +213,7 @@ def get_args():
print_smem_swizzle_layout(rows=args.rows,
logical_col_stride=args.logical_col_stride,
num_elems_per_128b=args.num_elems_per_128b,
smem_pading=args.smem_padding,
show_logical_col_id=args.show_logical_col_id,
use_logical_col_stride=args.use_logical_col_stride)

0 comments on commit 6c811c9

Please sign in to comment.