33import csv
44from dataclasses import asdict , dataclass
55import functools
6- import itertools
76import os
87import sys
98
@@ -68,15 +67,10 @@ def asdict(self):
6867
6968def generate_experiment_configs (
7069 dtype : torch .dtype ,
71- M : list [int ],
72- N : list [int ],
73- K : list [int ],
70+ shapes : list [tuple [int , int , int ]],
7471 backends : list [str ],
7572 device : torch .device ,
7673) -> list [ExperimentConfig ]:
77- # Generate cross config shapes from M, N, K lists
78- shapes = list (itertools .product (M , N , K ))
79-
8074 all_configs = []
8175 for shape in shapes :
8276 all_configs .append (
@@ -98,7 +92,7 @@ def get_single_backend_fn(backend: str):
9892 if backend == "torch_symm_mem" :
9993 return torch_symm_mem_gemm_rs
10094 if backend == "triton" :
101- return kraken .reduce_scatter_fusion .gemm_reduce_scatter
95+ return kraken .fused .gemm_reduce_scatter
10296 raise NotImplementedError (backend )
10397
10498
@@ -181,9 +175,7 @@ def main(args):
181175 torch .manual_seed (42 + local_rank )
182176
183177 results = []
184- configs = generate_experiment_configs (
185- args .dtype , args .M , args .N , args .K , args .backend , device
186- )
178+ configs = generate_experiment_configs (args .dtype , args .shape , args .backend , device )
187179 for config in configs :
188180 results .append (
189181 Experiment (
@@ -201,7 +193,7 @@ def shape_input_type(s):
201193 M , N , K = map (int , s .split ("," ))
202194 return M , N , K
203195 except Exception as e :
204- raise argparse .ArgumentTypeError ("Heads must be Hq,Hkv " ) from e
196+ raise argparse .ArgumentTypeError ("Shape must be M, N, K " ) from e
205197
206198
207199if __name__ == "__main__" :
@@ -233,27 +225,15 @@ def shape_input_type(s):
233225 )
234226
235227 parser .add_argument (
236- "-M" ,
237- type = shape_input_type ,
238- nargs = "+" ,
239- default = [2 ** x for x in range (7 , 11 )],
240- help = "matmul shapes: (M, N, K). (M, K) @ (K, N) -> (M, N)" ,
241- )
242-
243- parser .add_argument (
244- "-N" ,
228+ "--shape" ,
245229 type = shape_input_type ,
246230 nargs = "+" ,
247- default = [6656 ],
248- help = "matmul shapes: (M, N, K). (M, K) @ (K, N) -> (M, N)" ,
249- )
250-
251- parser .add_argument (
252- "-K" ,
253- type = shape_input_type ,
254- nargs = "+" ,
255- default = [2 ** x for x in range (12 , 16 )],
256- help = "matmul shapes: (M, N, K). (M, K) @ (K, N) -> (M, N)" ,
231+ default = [
232+ (m , 6656 , k )
233+ for m in [2 ** x for x in range (7 , 11 )]
234+ for k in [2 ** x for x in range (12 , 16 )]
235+ ],
236+ help = "matmul shapes: M, N, K. (M, K) @ (K, N) -> (M, N)" ,
257237 )
258238
259239 parser .add_argument ("-dtype" , type = str , help = "dtype" , default = "float32" )
0 commit comments