Skip to content

Commit f6f6c1a

Browse files
committed
transpose the output
1 parent 6649037 commit f6f6c1a

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

python/tvm/relax/frontend/torch/fx_translator.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1015,7 +1015,9 @@ def _scaled_dot_product_attention(self, node: fx.Node) -> relax.Var:
10151015
assert "float" in attn_mask.struct_info.dtype, msg
10161016

10171017
return self.block_builder.emit(
1018-
relax.op.nn.attention(query, key, value, bias=attn_mask, causal_mask=causal_mask)
1018+
transpose_S_H(
1019+
relax.op.nn.attention(query, key, value, bias=attn_mask, causal_mask=causal_mask)
1020+
)
10191021
)
10201022

10211023
def _unbind(self, node: fx.Node) -> relax.Var:

0 commit comments

Comments
 (0)