
Commit c02592d

Authored by Frida Hou
[None][autodeploy] Add group attention pattern for solar-pro-preview (#7054)
Signed-off-by: Frida Hou <[email protected]>
1 parent 0e30fe4 · commit c02592d

File tree

  • tensorrt_llm/_torch/auto_deploy/transform/library/attention.py

1 file changed: +22 −0 lines

tensorrt_llm/_torch/auto_deploy/transform/library/attention.py

Lines changed: 22 additions & 0 deletions
@@ -324,6 +324,16 @@ def _grouped_attn_replacement_4(q, k, v, attn_mask, dropout_p, scale):
     )
 
 
+def _grouped_attn_pattern_5(q, k, v, n_rep, attn_mask):
+    k = torch.ops.auto_deploy.torch_attention_repeat_kv(k, n_rep)
+    v = torch.ops.auto_deploy.torch_attention_repeat_kv(v, n_rep)
+    return torch.ops.auto_deploy.torch_attention_sdpa.default(q, k, v, attn_mask)
+
+
+def _grouped_attn_replacement_5(q, k, v, n_rep, attn_mask):
+    return torch.ops.auto_deploy.torch_attention_grouped_sdpa.default(q, k, v, attn_mask)
+
+
 @TransformRegistry.register("match_repeat_kv")
 class MatchRepeatKV(BaseTransform):
     """
@@ -423,6 +433,7 @@ def register_grouped_attention(patterns: ADPatternMatcherPass):
 
             dummy_args_1 = [q, k1, v1, n_rep, attn_mask, dropout, scale]
             dummy_args_2 = [q, k1, v1, attn_mask, dropout, scale]
+            dummy_args_3 = [q, k1, v1, n_rep, attn_mask]
 
             register_ad_pattern(
                 search_fn=_grouped_attn_pattern_1,
@@ -458,8 +469,19 @@ def register_grouped_attention(patterns: ADPatternMatcherPass):
                     "dropout_p": dropout,
                 },
             )
+            register_ad_pattern(
+                search_fn=_grouped_attn_pattern_5,
+                replace_fn=_grouped_attn_replacement_5,
+                patterns=patterns,
+                dummy_args=dummy_args_3,
+                scalar_workaround={"n_rep": n_rep},
+            )
 
         num_grouped_patterns = _apply_pattern(gm, "Grouped Attention", register_grouped_attention)
+        if num_grouped_patterns == 0:
+            ad_logger.warning(
+                "Fail to find any Group Attention Pattern, output or performance may be incorrect"
+            )
 
         info = TransformInfo(
             skipped=False,