Skip to content

Commit 2ada4ec

Browse files
[CI] Removes debug print statements from the example. (#1030)
* [CI] Removes debug print statements from the example.

* add parse args

* [Lint]: [pre-commit.ci] auto fixes [...]

* format

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent e59e7f9 commit 2ada4ec

File tree

1 file changed

+57
-12
lines changed

1 file changed

+57
-12
lines changed

examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py

Lines changed: 57 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import torch
77
from dequantize_utils import torch_convert_bit_twiddling, assert_similar
88
from tilelang.autotuner import set_autotune_inputs
9+
import argparse
910

1011

1112
def get_configs():
@@ -433,13 +434,18 @@ def get_data(m, n, k, qk, scale_size, topk, E, block_M):
433434
expert_ids = torch.tensor(expert_ids, dtype=torch.int32, device='cuda') # (padding_M,)
434435
padding_M = sorted_token_ids.shape[0] # padding_M: token number after padding
435436

436-
print(f'{sorted_token_ids=}')
437-
print(f'{expert_ids=}')
438-
439437
return A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids, padding_M
440438

441439

442-
def main(m=256, n=256, k=256, scale_size=32, fast_dequant=True, with_bias=False, topk=4, E=32):
440+
def main(m=256,
441+
n=256,
442+
k=256,
443+
scale_size=32,
444+
topk=4,
445+
E=32,
446+
fast_dequant=True,
447+
with_bias=False,
448+
tune=False):
443449
# Tunable parameters
444450
block_M, block_N, block_K = 128, 256, 128 # noqa: F841
445451
num_stages = 1 # noqa: F841
@@ -453,8 +459,25 @@ def main(m=256, n=256, k=256, scale_size=32, fast_dequant=True, with_bias=False,
453459
A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids, padding_M = get_data(
454460
m, n, k, qk, scale_size, topk, E, block_M)
455461

456-
with set_autotune_inputs([A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids]):
457-
# Autotune with inputs manually composed
462+
if tune:
463+
with set_autotune_inputs([A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids]):
464+
# Autotune with inputs manually composed
465+
kernel = matmul(
466+
m,
467+
n,
468+
k,
469+
topk,
470+
E,
471+
padding_M,
472+
"bfloat16",
473+
"bfloat16",
474+
"float32",
475+
num_bits=num_bits,
476+
scale_size=scale_size,
477+
fast_dequant=fast_dequant,
478+
with_bias=with_bias,
479+
)
480+
else:
458481
kernel = matmul(
459482
m,
460483
n,
@@ -469,8 +492,13 @@ def main(m=256, n=256, k=256, scale_size=32, fast_dequant=True, with_bias=False,
469492
scale_size=scale_size,
470493
fast_dequant=fast_dequant,
471494
with_bias=with_bias,
495+
block_M=block_M,
496+
block_N=block_N,
497+
block_K=block_K,
498+
num_stages=num_stages,
499+
threads=threads,
500+
split=split,
472501
)
473-
print(f'Best config: {kernel.config}')
474502

475503
output = kernel(
476504
A,
@@ -504,8 +532,25 @@ def main(m=256, n=256, k=256, scale_size=32, fast_dequant=True, with_bias=False,
504532

505533

506534
if __name__ == "__main__":
507-
M, N, K = 16384, 5760, 2944 # From gpt-oss-20b MoE's first gemm
508-
scale_size = 32
509-
topk = 4 # experts activated for each token
510-
E = 32 # number of experts
511-
main(M, N, K, scale_size, fast_dequant=True, with_bias=True, topk=topk, E=E)
535+
parser = argparse.ArgumentParser()
536+
parser.add_argument(
537+
"--M", type=int, default=16384, help="M") # From gpt-oss-20b MoE's first gemm
538+
parser.add_argument("--N", type=int, default=5760, help="N")
539+
parser.add_argument("--K", type=int, default=2944, help="K")
540+
parser.add_argument("--scale_size", type=int, default=32, help="scale size")
541+
parser.add_argument(
542+
"--topk", type=int, default=4, help="topk") # experts activated for each token
543+
parser.add_argument("--E", type=int, default=32, help="E") # number of experts
544+
parser.add_argument("--tune", action="store_true", help="tune configs")
545+
args = parser.parse_args()
546+
547+
main(
548+
args.M,
549+
args.N,
550+
args.K,
551+
args.scale_size,
552+
topk=args.topk,
553+
E=args.E,
554+
fast_dequant=True,
555+
with_bias=True,
556+
tune=args.tune)

0 commit comments

Comments (0)