Skip to content

Commit 49c176b

Browse files
committed
Add runtime swap AB for SM100 FP8 blockwise GEMM
Signed-off-by: Barry Kang <[email protected]>
1 parent a15e333 commit 49c176b

File tree

2 files changed

+77
-15
lines changed

2 files changed

+77
-15
lines changed

tensorrt_llm/_torch/modules/linear.py

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -574,14 +574,29 @@ def apply(self, module: Linear, input: torch.Tensor,
574574

575575
if get_sm_version() == 100:
576576
import deep_gemm
577-
a, a_sf = fp8_utils.per_token_quant_and_transform(input)
578-
output = torch.empty((input.shape[0], module.weight.shape[0]),
579-
device=input.device,
580-
dtype=torch.bfloat16)
581-
deep_gemm.fp8_gemm_nt((a, a_sf),
582-
(module.weight, module.weight_scale),
583-
output,
584-
disable_ue8m0_cast=True)
577+
if input.shape[0] < 128:
578+
# Swap AB
579+
a, a_sf = fp8_utils.per_token_quant_and_transform(input,
580+
swap_ab=True)
581+
output_padded = torch.empty(
582+
(module.weight.shape[0], a.shape[0]),
583+
device=input.device,
584+
dtype=torch.bfloat16)
585+
deep_gemm.fp8_gemm_nt((module.weight, module.weight_scale),
586+
(a, a_sf),
587+
output_padded,
588+
disable_ue8m0_cast=True)
589+
output = fp8_utils.masked_transpose(output_padded,
590+
input.shape[0])
591+
else:
592+
a, a_sf = fp8_utils.per_token_quant_and_transform(input)
593+
output = torch.empty((input.shape[0], module.weight.shape[0]),
594+
device=input.device,
595+
dtype=torch.bfloat16)
596+
deep_gemm.fp8_gemm_nt((a, a_sf),
597+
(module.weight, module.weight_scale),
598+
output,
599+
disable_ue8m0_cast=True)
585600
else:
586601
act_input_fp8, act_input_sf = torch.ops.trtllm.fp8_quantize_1x128(
587602
input)

tensorrt_llm/quantization/utils/fp8_utils.py

Lines changed: 54 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -336,7 +336,7 @@ def silu_and_mul_masked_post_quant_fwd(
336336
scale_k = ceil_div(k, quant_group_size)
337337
m_padded = align(m, alignment)
338338
scale_k_padded = align(scale_k, alignment)
339-
output_scale = torch.zeros((g, scale_k_padded // 4, m_padded),
339+
output_scale = torch.empty((g, scale_k_padded // 4, m_padded),
340340
dtype=torch.int32,
341341
device='cuda')
342342

@@ -458,6 +458,7 @@ def per_token_quant_and_transform(
458458
input: torch.Tensor,
459459
quant_group_size: int = 128,
460460
scale_ue8m0: bool = True,
461+
swap_ab=False,
461462
):
462463
"""
463464
input shape [g, m, k]
@@ -477,18 +478,21 @@ def per_token_quant_and_transform(
477478
fp8_min = -fp8_max
478479

479480
m, k = input.shape
481+
m_padded = m if not swap_ab else align(m, 8)
480482

481483
# Create output
482-
output = torch.empty((m, k), dtype=torch.float8_e4m3fn, device="cuda")
484+
output = torch.empty((m_padded, k),
485+
dtype=torch.float8_e4m3fn,
486+
device=input.device)
483487

484488
# Create output scale
485489
alignment = 4
486490
scale_k = ceil_div(k, quant_group_size)
487-
m_padded = align(m, alignment)
491+
m_aligned = align(m_padded, alignment)
488492
scale_k_padded = align(scale_k, alignment)
489-
output_scale = torch.zeros((scale_k_padded // 4, m_padded),
493+
output_scale = torch.empty((scale_k_padded // 4, m_aligned),
490494
dtype=torch.int32,
491-
device='cuda')
495+
device=input.device)
492496

493497
# Get block/grid/stage/warp
494498
BLOCK_NUM_PER_EXPERT = 64
@@ -518,13 +522,56 @@ def per_token_quant_and_transform(
518522
num_warps=num_warps,
519523
SCALE_UE8M0=scale_ue8m0,
520524
)
521-
output_scale = output_scale.transpose(0, 1)[:m, :]
525+
output_scale = output_scale.transpose(0, 1)[:m_padded, :]
522526
check_sf_layout(
523527
output_scale,
524-
m,
528+
m_padded,
525529
k,
526530
(1, 128),
527531
num_groups=None,
528532
tma_stride_check=True,
529533
)
530534
return output, output_scale
535+
536+
537+
@triton.jit
538+
def _transpose_kernel(input_ptr, output_ptr, M, N, stride_in_m, stride_in_n,
539+
stride_out_m, stride_out_n, BLOCK_SIZE: tl.constexpr):
540+
row_block = tl.program_id(0)
541+
col_block = tl.program_id(1)
542+
543+
row = row_block * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
544+
col = col_block * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
545+
546+
mask_row = row < M
547+
mask_col = col < N
548+
mask = mask_row[:, None] & mask_col[None, :]
549+
550+
input_idx = row[:, None] * stride_in_m + col[None, :] * stride_in_n
551+
data = tl.load(input_ptr + input_idx, mask=mask, other=0)
552+
553+
output_idx = row[:, None] * stride_out_n + col[None, :] * stride_out_m
554+
tl.store(output_ptr + output_idx, data, mask=mask)
555+
556+
557+
def masked_transpose(input: torch.Tensor, n_available: int) -> torch.Tensor:
558+
M, N = input.shape
559+
BLOCK_SIZE = 32
560+
output = torch.empty((n_available, M),
561+
dtype=input.dtype,
562+
device=input.device)
563+
564+
grid = ((M + BLOCK_SIZE - 1) // BLOCK_SIZE,
565+
(n_available + BLOCK_SIZE - 1) // BLOCK_SIZE)
566+
_transpose_kernel[grid](
567+
input,
568+
output,
569+
M,
570+
n_available,
571+
input.stride(0),
572+
input.stride(1),
573+
output.stride(0),
574+
output.stride(1),
575+
BLOCK_SIZE=BLOCK_SIZE,
576+
)
577+
return output

0 commit comments

Comments
 (0)