@@ -324,6 +324,16 @@ def _grouped_attn_replacement_4(q, k, v, attn_mask, dropout_p, scale):
     )
 
 
+def _grouped_attn_pattern_5(q, k, v, n_rep, attn_mask):
+    k = torch.ops.auto_deploy.torch_attention_repeat_kv(k, n_rep)
+    v = torch.ops.auto_deploy.torch_attention_repeat_kv(v, n_rep)
+    return torch.ops.auto_deploy.torch_attention_sdpa.default(q, k, v, attn_mask)
+
+
+def _grouped_attn_replacement_5(q, k, v, n_rep, attn_mask):
+    return torch.ops.auto_deploy.torch_attention_grouped_sdpa.default(q, k, v, attn_mask)
+
+
 @TransformRegistry.register("match_repeat_kv")
 class MatchRepeatKV(BaseTransform):
     """
@@ -423,6 +433,7 @@ def register_grouped_attention(patterns: ADPatternMatcherPass):
 
             dummy_args_1 = [q, k1, v1, n_rep, attn_mask, dropout, scale]
             dummy_args_2 = [q, k1, v1, attn_mask, dropout, scale]
+            dummy_args_3 = [q, k1, v1, n_rep, attn_mask]
 
             register_ad_pattern(
                 search_fn=_grouped_attn_pattern_1,
@@ -458,8 +469,19 @@ def register_grouped_attention(patterns: ADPatternMatcherPass):
                     "dropout_p": dropout,
                 },
             )
+            register_ad_pattern(
+                search_fn=_grouped_attn_pattern_5,
+                replace_fn=_grouped_attn_replacement_5,
+                patterns=patterns,
+                dummy_args=dummy_args_3,
+                scalar_workaround={"n_rep": n_rep},
+            )
 
         num_grouped_patterns = _apply_pattern(gm, "Grouped Attention", register_grouped_attention)
+        if num_grouped_patterns == 0:
+            ad_logger.warning(
+                "Failed to find any grouped attention pattern; output or performance may be incorrect"
+            )
 
         info = TransformInfo(
             skipped=False,
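
Note on the new pattern: pattern 5 covers the repeat_kv + SDPA call chain that carries no explicit dropout or scale arguments, rewriting it to a single grouped SDPA call so downstream attention backends see the un-repeated KV heads. The sketch below illustrates, in plain PyTorch rather than the auto_deploy custom ops, the equivalence the rewrite relies on; the repeat_kv helper and the enable_gqa flag (available in recent PyTorch releases) are illustrative assumptions, not part of this commit.

import torch
import torch.nn.functional as F


def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    # Expand [batch, num_kv_heads, seq, head_dim] to
    # [batch, num_kv_heads * n_rep, seq, head_dim] by repeating each KV head.
    b, h_kv, s, d = x.shape
    return x[:, :, None].expand(b, h_kv, n_rep, s, d).reshape(b, h_kv * n_rep, s, d)


b, h_q, h_kv, s, d = 2, 8, 2, 16, 64
n_rep = h_q // h_kv
q = torch.randn(b, h_q, s, d)
k = torch.randn(b, h_kv, s, d)
v = torch.randn(b, h_kv, s, d)

# What the search pattern matches: explicit KV repetition followed by SDPA.
out_repeated = F.scaled_dot_product_attention(q, repeat_kv(k, n_rep), repeat_kv(v, n_rep))

# What the replacement computes: attention over the un-repeated KV heads
# (emulated here with enable_gqa; PyTorch >= 2.5).
out_grouped = F.scaled_dot_product_attention(q, k, v, enable_gqa=True)

assert torch.allclose(out_repeated, out_grouped, atol=1e-4)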