Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 22 additions & 1 deletion benchmarks/routines/norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,13 @@ def parse_norm_args(line, parser):
default=False,
help="Use swizzled scale factor layout for tensor core GEMM. Default: False",
)
parser.add_argument(
"--output_both_sf_layouts",
action="store_true",
default=False,
help="Output both swizzled and unswizzled scale factors. When enabled, "
"overrides --is_sf_swizzled_layout and returns both layouts. Default: False",
)

args = parser.parse_args(line)
if args.verbose >= 1:
Expand Down Expand Up @@ -799,6 +806,13 @@ def testRmsnormFp4quant(args):
if run_refcheck:
print("[WARNING] --refcheck is not supported for rmsnorm_fp4quant.")

# Warn user that output_both_sf_layouts is not supported for rmsnorm_fp4quant
if args.output_both_sf_layouts:
print(
"[WARNING] --output_both_sf_layouts is not supported for rmsnorm_fp4quant. "
"Use add_rmsnorm_fp4quant instead. Flag will be ignored."
)

def run_backend(backend, input_tensor, weight):
if backend == "cute-dsl":
return flashinfer.rmsnorm_fp4quant(
Expand Down Expand Up @@ -912,6 +926,7 @@ def testAddRmsnormFp4quant(args):
out_dtype = args.out_dtype
use_global_scale = args.use_global_scale
is_sf_swizzled_layout = args.is_sf_swizzled_layout
output_both_sf_layouts = args.output_both_sf_layouts
is_cuda_graph_compatible = not args.no_cuda_graph
run_refcheck = args.refcheck
res = []
Expand Down Expand Up @@ -976,6 +991,7 @@ def testAddRmsnormFp4quant(args):
print(f"[VVERBOSE] {block_size = }")
print(f"[VVERBOSE] {use_global_scale = }")
print(f"[VVERBOSE] {is_sf_swizzled_layout = }")
print(f"[VVERBOSE] {output_both_sf_layouts = }")

# Warn user that refcheck is not supported for FP4 quantization fusion
if run_refcheck:
Expand All @@ -991,6 +1007,7 @@ def run_backend(backend, input_tensor, residual_tensor, weight):
block_size=block_size,
global_scale=global_scale,
is_sf_swizzled_layout=is_sf_swizzled_layout,
output_both_sf_layouts=output_both_sf_layouts,
)
else:
raise ValueError(f"Unsupported backend: {backend}")
Expand Down Expand Up @@ -1019,14 +1036,17 @@ def run_backend(backend, input_tensor, residual_tensor, weight):
num_scale_elements = num_elements // block_size
# FP4: 2 elements per byte (4 bits each)
fp4_output_bytes = num_elements // 2
# Scale factors: 1 byte each. When output_both_sf_layouts=True, write 2x scale factors
sf_write_multiplier = 2 if output_both_sf_layouts else 1
problem_bytes = (
num_elements * input_dtype.itemsize # input read
+ num_elements * input_dtype.itemsize # residual read
+ hidden_size * input_dtype.itemsize # weight read
+ num_elements
* input_dtype.itemsize # residual write (in-place: input + residual)
+ fp4_output_bytes # FP4 output write
+ num_scale_elements # scale factors write (1 byte each)
+ num_scale_elements
* sf_write_multiplier # scale factors write (1 byte each)
)
problem_flops = num_elements * 6 # rough estimate (add + rmsnorm ops)
tflops = problem_flops / (10**9 * median_time) # in TFLOPs/sec
Expand All @@ -1047,6 +1067,7 @@ def run_backend(backend, input_tensor, residual_tensor, weight):
cur_res["eps"] = eps
cur_res["use_global_scale"] = use_global_scale
cur_res["is_sf_swizzled_layout"] = is_sf_swizzled_layout
cur_res["output_both_sf_layouts"] = output_both_sf_layouts
cur_res["backend"] = backend
cur_res["case_tag"] = args.case_tag
res.append(cur_res)
Expand Down
10 changes: 10 additions & 0 deletions benchmarks/samples/sample_testlist.txt
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,16 @@
# 3D input shape (batch, num_heads, head_dim)
--routine add_rmsnorm_fp4quant --batch_size 32 --num_heads 32 --hidden_size 128 --input_dtype bfloat16 -vv --generate_repro_command --case_tag "add_rmsnorm_fp4quant_3d"

# Output both swizzled and unswizzled scale factors (for dual-use scenarios)
--routine add_rmsnorm_fp4quant --batch_size 32 --hidden_size 4096 --input_dtype bfloat16 --output_both_sf_layouts -vv --generate_repro_command --case_tag "add_rmsnorm_fp4quant_both_sf"
--routine add_rmsnorm_fp4quant --batch_size 64 --hidden_size 8192 --input_dtype bfloat16 --output_both_sf_layouts -vv --generate_repro_command --case_tag "add_rmsnorm_fp4quant_both_sf_large"

# Both SF layouts with global scale
--routine add_rmsnorm_fp4quant --batch_size 32 --hidden_size 4096 --input_dtype bfloat16 --use_global_scale --output_both_sf_layouts -vv --generate_repro_command --case_tag "add_rmsnorm_fp4quant_both_sf_global"

# Both SF layouts with MXFP4 format
--routine add_rmsnorm_fp4quant --batch_size 32 --hidden_size 4096 --input_dtype bfloat16 --out_dtype mxfp4 --output_both_sf_layouts -vv --generate_repro_command --case_tag "add_rmsnorm_fp4quant_mxfp4_both_sf"

## Quantization (Blackwell SM10.0+ only)
# MxFP8 Quantization - basic
--routine mxfp8_quantize --m 1024 --k 4096 --input_dtype bfloat16 -vv --generate_repro_command --case_tag "mxfp8_quantize_basic"
Expand Down
Loading
Loading