1616from onnxscript import script
1717from onnxscript .onnx_opset import opset18 as op
1818from onnxscript .onnx_types import FLOAT
19+ from onnxscript .rewriter .ort_fusions ._test_utils import assert_allclose , ort_run
1920from onnxscript .rewriter .ort_fusions .sdpa import fuse_sdpa
21+ from onnxscript .rewriter .ort_fusions .sdpa_via_mha import replace_sdpa_by_mha
2022
2123B = 2 # batch size
2224N = 4 # number of heads
@@ -190,7 +192,7 @@ def _masked_post_mul_sdpa_script(query, key, value, mask):
190192
191193
192194@script ()
193- def _custom_scale_pre_div_sdpa_script (query , key , value , mask ):
195+ def _masked_custom_scale_pre_div_sdpa_script (query , key , value , mask ):
194196 key_transposed = op .Transpose (key , perm = [0 , 1 , 3 , 2 ])
195197 divisor = op .Constant (value_float = SQRT_CUSTOM_DIV_SCALE_FACTOR )
196198 scaled_query = op .Div (query , divisor )
@@ -203,7 +205,7 @@ def _custom_scale_pre_div_sdpa_script(query, key, value, mask):
203205
204206
205207@script ()
206- def _custom_scale_pre_mul_sdpa_script (query , key , value , mask ):
208+ def _masked_custom_scale_pre_mul_sdpa_script (query , key , value , mask ):
207209 key_transposed = op .Transpose (key , perm = [0 , 1 , 3 , 2 ])
208210 multiplier = op .Constant (value_float = SQRT_CUSTOM_MUL_SCALE_FACTOR )
209211 scaled_query = op .Mul (query , multiplier )
@@ -216,7 +218,7 @@ def _custom_scale_pre_mul_sdpa_script(query, key, value, mask):
216218
217219
218220@script ()
219- def _custom_scale_post_div_sdpa_script (query , key , value , mask ):
221+ def _masked_custom_scale_post_div_sdpa_script (query , key , value , mask ):
220222 key_transposed = op .Transpose (key , perm = [0 , 1 , 3 , 2 ])
221223 divisor = op .Constant (value_float = CUSTOM_DIV_SCALE_FACTOR )
222224 attn_score = op .MatMul (query , key_transposed )
@@ -228,7 +230,7 @@ def _custom_scale_post_div_sdpa_script(query, key, value, mask):
228230
229231
230232@script ()
231- def _custom_scale_post_mul_sdpa_script (query , key , value , mask ):
233+ def _masked_custom_scale_post_mul_sdpa_script (query , key , value , mask ):
232234 key_transposed = op .Transpose (key , perm = [0 , 1 , 3 , 2 ])
233235 multiplier = op .Constant (value_float = CUSTOM_MUL_SCALE_FACTOR )
234236 attn_score = op .MatMul (query , key_transposed )
@@ -240,15 +242,19 @@ def _custom_scale_post_mul_sdpa_script(query, key, value, mask):
240242
241243
242244class SDPATestCase :
243- def __init__ (self , script_func ):
245+ def __init__ (self , script_func , * , with_mask ):
244246 self .script_func = script_func
247+ self .with_mask = with_mask
245248
246249 def get_onnx_model (self ):
247250 if not hasattr (self , "_onnx_model" ):
248251 qkv_type = FLOAT [B , N , S , H ]
249252 mask_type = FLOAT [B , N , S , S ]
253+ input_types = [qkv_type , qkv_type , qkv_type ]
254+ if self .with_mask :
255+ input_types .append (mask_type )
250256 model_proto = self .script_func .to_model_proto (
251- input_types = [ qkv_type , qkv_type , qkv_type , mask_type ] , output_types = [qkv_type ]
257+ input_types = input_types , output_types = [qkv_type ]
252258 )
253259 self ._onnx_model = ir .serde .deserialize_model (model_proto )
254260 return self ._onnx_model
@@ -259,6 +265,35 @@ def get_ort_inputs(self):
259265 "query" : numpy .random .rand (B , N , S , H ).astype (numpy .float32 ),
260266 "key" : numpy .random .rand (B , N , S , H ).astype (numpy .float32 ),
261267 "value" : numpy .random .rand (B , N , S , H ).astype (numpy .float32 ),
268+ }
269+ if self .with_mask :
270+ inputs ["mask" ] = numpy .random .rand (B , N , S , S ).astype (numpy .float32 )
271+ self ._ort_inputs = inputs
272+ return self ._ort_inputs
273+
274+
275+ class InvalidSDPATestCase :
276+ def __init__ (self , script_func ):
277+ self .script_func = script_func
278+
279+ def get_onnx_model (self ):
280+ if not hasattr (self , "_onnx_model" ):
281+ qk_type = FLOAT [B , N , S , H ]
282+ # We broadcast value in the batch dimension, which is not supported by SDPA fusion
283+ v_type = FLOAT [1 , N , S , H ]
284+ mask_type = FLOAT [B , N , S , S ]
285+ model_proto = self .script_func .to_model_proto (
286+ input_types = [qk_type , qk_type , v_type , mask_type ], output_types = [qk_type ]
287+ )
288+ self ._onnx_model = ir .serde .deserialize_model (model_proto )
289+ return self ._onnx_model
290+
291+ def get_ort_inputs (self ):
292+ if not hasattr (self , "_ort_inputs" ):
293+ inputs = {
294+ "query" : numpy .random .rand (B , N , S , H ).astype (numpy .float32 ),
295+ "key" : numpy .random .rand (B , N , S , H ).astype (numpy .float32 ),
296+ "value" : numpy .random .rand (1 , N , S , H ).astype (numpy .float32 ),
262297 "mask" : numpy .random .rand (B , N , S , S ).astype (numpy .float32 ),
263298 }
264299 self ._ort_inputs = inputs
@@ -296,35 +331,35 @@ def get_ort_inputs(self):
296331class TestSDPAFusion (unittest .TestCase ):
297332 @parameterized .parameterized .expand (
298333 [
299- ("unmasked_pre_div " , _unmasked_pre_div_sdpa_script ),
300- ("unmasked_pre_mul " , _unmasked_pre_mul_sdpa_script ),
301- ("unmasked_post_div " , _unmasked_post_div_sdpa_script ),
302- ("unmasked_post_mul " , _unmasked_post_mul_sdpa_script ),
303- ("pre_div " , _masked_pre_div_sdpa_script ),
304- ("pre_mul " , _masked_pre_mul_sdpa_script ),
305- ("post_div " , _masked_post_div_sdpa_script ),
306- ("post_mul " , _masked_post_mul_sdpa_script ),
334+ ("pre_div " , _unmasked_pre_div_sdpa_script ),
335+ ("pre_mul " , _unmasked_pre_mul_sdpa_script ),
336+ ("post_div " , _unmasked_post_div_sdpa_script ),
337+ ("post_mul " , _unmasked_post_mul_sdpa_script ),
338+ ("masked_pre_div " , _masked_pre_div_sdpa_script ),
339+ ("masked_pre_mul " , _masked_pre_mul_sdpa_script ),
340+ ("masked_post_div " , _masked_post_div_sdpa_script ),
341+ ("masked_post_mul " , _masked_post_mul_sdpa_script ),
307342 ("custom_scale_post_mul" , _custom_scale_post_mul_sdpa_script ),
308343 ("custom_scale_post_div" , _custom_scale_post_div_sdpa_script ),
309344 ("custom_scale_pre_mul" , _custom_scale_pre_mul_sdpa_script ),
310345 ("custom_scale_pre_div" , _custom_scale_pre_div_sdpa_script ),
311- ("custom_scale_post_mul_masked " , _custom_scale_post_mul_sdpa_script ),
312- ("custom_scale_post_div_masked " , _custom_scale_post_div_sdpa_script ),
313- ("custom_scale_pre_mul_masked " , _custom_scale_pre_mul_sdpa_script ),
314- ("custom_scale_pre_div_masked " , _custom_scale_pre_div_sdpa_script ),
346+ ("masked_custom_scale_post_mul " , _masked_custom_scale_post_mul_sdpa_script ),
347+ ("masked_custom_scale_post_div " , _masked_custom_scale_post_div_sdpa_script ),
348+ ("masked_custom_scale_pre_mul " , _masked_custom_scale_pre_mul_sdpa_script ),
349+ ("masked_custom_scale_pre_div " , _masked_custom_scale_pre_div_sdpa_script ),
315350 (
316351 "_custom_multi_scale_pre_mul_sdpa_script" ,
317352 _custom_multi_scale_pre_mul_sdpa_script ,
318353 ),
319354 ]
320355 )
321356 def test_sdpa_fusion (self , name , script_func ):
322- test_case = SDPATestCase (script_func )
357+ test_case = SDPATestCase (script_func , with_mask = "masked" in name )
323358 model = test_case .get_onnx_model ()
324359 onnxscript .optimizer .optimize (model )
325360
326- # inputs = test_case.get_ort_inputs()
327- # original_outputs = ort_run("original", model, inputs)
361+ inputs = test_case .get_ort_inputs ()
362+ original_outputs = ort_run ("original" , model , inputs )
328363
329364 count = fuse_sdpa (model , debug = True )
330365 self .assertGreater (count , 0 )
@@ -347,8 +382,19 @@ def test_sdpa_fusion(self, name, script_func):
347382 # of scale_factor (is =default_scaling_factor)
348383 self .assertIsNone (sdpa_node .attributes .get ("scale" ))
349384
350- # new_outputs = ort_run("optimized", model, inputs)
351- # assert_allclose(new_outputs, original_outputs)
385+ replace_sdpa_by_mha (model , debug = True )
386+
387+ self .assertNotIn ("SDPA" , [n .op_type for n in model .graph ])
388+
389+ new_outputs = ort_run ("optimized" , model , inputs )
390+ assert_allclose (new_outputs , original_outputs )
391+
392+ def test_invalid_sdpa_fusion_value_batch_dim (self ):
393+ test_case = InvalidSDPATestCase (_masked_pre_mul_sdpa_script )
394+ model = test_case .get_onnx_model ()
395+ onnxscript .optimizer .optimize (model )
396+ count = fuse_sdpa (model )
397+ self .assertEqual (count , 0 )
352398
353399 def test_invalid_sdpa_fusion_value_batch_dim (self ):
354400 test_case = InvalidSDPATestCase (_masked_pre_mul_sdpa_script )
0 commit comments