@@ -312,3 +312,59 @@ def _nn_group_norm(bb: BlockBuilder, call: Call) -> Expr:
 def _nn_dropout(bb: BlockBuilder, call: Call) -> Expr:
     logging.info("Dropout is handled by frontend translator at this moment and is not legalized.")
     return call
+
+
+def _te_attention(
+    q: te.Tensor, k: te.Tensor, v: te.Tensor, bias: te.Tensor, scale: tir.FloatImm
+) -> te.Tensor:
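+    """Compute softmax(Q * K^T * scale + bias) * V, with q/k/v in [batch, seq, num_head, head_dim] layout."""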
+    batch_size, seq_len, num_head, head_dim = q.shape
+    _, seq_len_kv, _, head_dim_v = v.shape
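+    # Fold the head axis into the batch axis so the whole computation can run
+    # as a pair of 3-D topi.nn.batch_matmul calls.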
+    q = topi.transpose(q, [0, 2, 1, 3])
+    k = topi.transpose(k, [0, 2, 1, 3])
+    v = topi.transpose(v, [0, 2, 1, 3])
+    q = topi.reshape(q, [batch_size * num_head, seq_len, head_dim])
+    k = topi.reshape(k, [batch_size * num_head, seq_len_kv, head_dim])
+    v = topi.reshape(v, [batch_size * num_head, seq_len_kv, head_dim_v])
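+    # batch_matmul transposes its second operand by default, so this is Q * K^T:
+    # [b*n, seq, d] x [b*n, seq_kv, d] -> [b*n, seq, seq_kv].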
+    p = topi.nn.batch_matmul(q, k)
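+    # Apply the user-supplied scale if present, otherwise the usual 1/sqrt(head_dim).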
+    if scale is not None:
+        p = topi.multiply(p, scale)
+    else:
+        p = topi.divide(p, tir.sqrt(tir.Cast(p.dtype, head_dim)))
+    if bias is not None:
+        p = topi.reshape(p, [batch_size, num_head, seq_len, seq_len_kv])
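+        # Broadcast lower-rank biases against the 4-D scores: a 2-D bias is
+        # [batch, seq_kv], a 3-D bias is [batch, seq, seq_kv]; 4-D passes through.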
+        if len(bias.shape) == 2:
+            bias = topi.reshape(bias, [batch_size, 1, 1, seq_len_kv])
+        elif len(bias.shape) == 3:
+            bias = topi.reshape(bias, [batch_size, 1, seq_len, seq_len_kv])
+        p = topi.add(p, bias)
+        p = topi.reshape(p, [batch_size * num_head, seq_len, seq_len_kv])
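+    # Softmax over the key axis, then weight the values; v is already
+    # [b*n, seq_kv, head_dim_v], hence transpose_b=False.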
+    s = topi.nn.softmax(p)
+    o = topi.nn.batch_matmul(s, v, transpose_b=False)
+    o = topi.reshape(o, [batch_size, num_head, seq_len, head_dim_v])
+    return topi.transpose(o, [0, 2, 1, 3])
+
+
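+# relax.nn.attention carries no bias operand, so the TE kernel receives bias=None;
+# relax.nn.attention_bias passes the bias through as a fourth argument.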
+@register_legalize("relax.nn.attention")
+def _nn_attention(bb: BlockBuilder, call: Call) -> Expr:
+    return bb.call_te(
+        _te_attention,
+        call.args[0],
+        call.args[1],
+        call.args[2],
+        None,
+        call.attrs.scale,
+        primfunc_name_hint="attention",
+    )
+
+
+@register_legalize("relax.nn.attention_bias")
+def _nn_attention_bias(bb: BlockBuilder, call: Call) -> Expr:
+    return bb.call_te(
+        _te_attention,
+        call.args[0],
+        call.args[1],
+        call.args[2],
+        call.args[3],
+        call.attrs.scale,
+        primfunc_name_hint="attention_bias",
+    )
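
Usage sketch (not part of the diff): a minimal example of exercising the new legalization, assuming the standard tvm.relax Python APIs; the module name and tensor shapes below are illustrative, not taken from this commit.

import tvm
from tvm import relax
from tvm.script import relax as R


@tvm.script.ir_module
class Module:
    @R.function
    def main(
        q: R.Tensor((4, 16, 8, 32), "float32"),
        k: R.Tensor((4, 16, 8, 32), "float32"),
        v: R.Tensor((4, 16, 8, 32), "float32"),
    ) -> R.Tensor((4, 16, 8, 32), "float32"):
        gv: R.Tensor((4, 16, 8, 32), "float32") = R.nn.attention(q, k, v)
        return gv


# LegalizeOps rewrites the relax.nn.attention call into a call to the
# TE-based "attention" PrimFunc registered above.
mod = relax.transform.LegalizeOps()(Module)
mod.show()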