Skip to content

Commit 4b7d78d

Browse files
authored
[Relax] Handle dynamic arguments in legalization of nn.attention (#16592)
Prior to this commit, when using `causal_mask="BottomRight"` in `R.nn.attention`, the legalization would assume that the query and key/value sequence lengths were static integers. This commit updates the legalization to allow dynamic shapes.
1 parent 8f42597 commit 4b7d78d

File tree

2 files changed

+25
-1
lines changed

2 files changed

+25
-1
lines changed

python/tvm/relax/transform/legalize_ops/nn.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -486,7 +486,7 @@ def _te_attention(
486486
if causal_mask == "TopLeft":
487487
offset = tir.IntImm("int32", 0)
488488
elif causal_mask == "BottomRight":
489-
offset = tir.IntImm("int32", abs(seq_len - seq_len_kv))
489+
offset = tir.abs(seq_len - seq_len_kv).astype("int32")
490490
else:
491491
raise NotImplementedError()
492492
p_masked = topi.trilu(p, k=offset, upper=False)

tests/python/relax/test_transform_legalize_ops_nn.py

Lines changed: 24 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -3270,6 +3270,30 @@ def main(q: R.Tensor((4, 16, 32, 8), dtype="float32"), k: R.Tensor((4, 8, 32, 8)
32703270
tvm.ir.assert_structural_equal(mod, Expected)
32713271

32723272

3273+
def test_dynamic_attention():
3274+
"""The sequence lengths may be dynamic
3275+
3276+
In previous implementations, the `seq_len` and `seq_len_kv` were
3277+
assumed to be static integers, and produced an exception during
3278+
legalization.
3279+
"""
3280+
3281+
@tvm.script.ir_module
3282+
class Attention:
3283+
@R.function
3284+
def main(
3285+
q: R.Tensor((4, "seq_len", 32, 8), "float32"),
3286+
k: R.Tensor((4, "seq_len_kv", 32, 8), "float32"),
3287+
v: R.Tensor((4, "seq_len_kv", 32, 16), "float32"),
3288+
bias: R.Tensor((4, 32, "seq_len", "seq_len_kv"), "float32"),
3289+
):
3290+
scale = T.FloatImm("float32", 0.1)
3291+
gv = R.nn.attention(q, k, v, bias, scale=scale, causal_mask="BottomRight")
3292+
return gv
3293+
3294+
LegalizeOps()(Attention)
3295+
3296+
32733297
def test_nll_loss():
32743298
# fmt: off
32753299
@tvm.script.ir_module

0 commit comments

Comments
 (0)