1 change: 1 addition & 0 deletions python/tvm/relax/op/nn/__init__.py
@@ -20,6 +20,7 @@
adaptive_avg_pool2d,
adaptive_avg_pool3d,
attention,
attention_bias,
attention_var_len,
avg_pool1d,
avg_pool2d,
97 changes: 97 additions & 0 deletions python/tvm/relax/op/nn/nn.py
@@ -1837,6 +1837,103 @@ def attention(
) # type: ignore


def attention_bias(
query: Expr,
key: Expr,
value: Expr,
bias: Optional[Expr] = None,
scale: Optional[FloatImm] = None,
causal_mask: Optional[str] = None,
window_size: Optional[int] = None,
) -> Expr:
r"""Computes fused multi-head attention.

IRModule.script() prints the attention op as attention_bias, which the
TVMScript parser cannot parse back as attention.
This function makes TVMScript's printed output compatible with the TVMScript parser.

All input tensors are 4-D tensors with BSNH layout.

.. math::
FMA(Q, K, V) = \text{Softmax}(Q @ K^T * \text{scale} + \text{bias}) @ V

.. note::
The input tensors are required to have float16 dtype.

Parameters
----------
query: relax.Expr
The input query to the operator. The layout of the input query should be
(batch_size, seq_len, num_head, head_dim).

key: relax.Expr
The input key to the operator. The layout of the input key should be
(batch_size, seq_len_kv, num_head, head_dim).

value: relax.Expr
The input value to the operator. The layout of the input value should be
(batch_size, seq_len_kv, num_head, head_dim_v).

bias: Optional[Expr]
The optional attention bias to the operator. The layout of the attention bias should be
a 4-D tensor ending with seq_len_kv, and broadcastable to
(batch_size, num_head, seq_len, seq_len_kv).

scale: Optional[FloatImm]
The scale value to be applied to the attention score, by default 1 / sqrt(head_dim).

causal_mask: Optional[str]
The optional causal mask, either 'TopLeft' or 'BottomRight'.
For 'TopLeft', the mask matrix is `np.tril(*, k=0)`,
while for 'BottomRight', it is `np.tril(*, k=abs(seq_len - seq_len_kv))`.
For example, with seq_len = 4 and seq_len_kv = 2,
the mask for 'TopLeft' is:

.. code:: python

[[1, 0],
[1, 1],
[1, 1],
[1, 1]]

the mask for 'BottomRight' is:

.. code:: python

[[1, 1],
[1, 1],
[1, 1],
[1, 1]]

With seq_len = 2 and seq_len_kv = 4,
the mask for 'TopLeft' is:

.. code:: python

[[1, 0, 0, 0],
[1, 1, 0, 0]]

the mask for 'BottomRight' is:

.. code:: python

[[1, 1, 1, 0],
[1, 1, 1, 1]]

window_size: Optional[int]
The size of the window for sliding-window attention.

Returns
-------
result : relax.Expr
The computed result. The layout of the output should be
(batch_size, seq_len, num_head, head_dim_v).
"""
return _ffi_api.attention(
query, key, value, bias, scale, causal_mask, window_size
) # type: ignore


def attention_var_len(
queries: Expr,
keys: Expr,
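For reference, the semantics documented in the attention_bias docstring can be sketched in plain NumPy. This is a minimal illustrative sketch, not TVM API: the helper `attention_ref` and its argument handling are assumptions made here (`window_size` is omitted, masked scores are set to `-inf` before the softmax, and the float16 requirement is not enforced); the causal masks follow the `np.tril` construction described above.

.. code:: python

    import numpy as np

    def attention_ref(q, k, v, bias=None, scale=None, causal_mask=None):
        # q: (b, s, n, h); k, v: (b, s_kv, n, h / h_v); output: (b, s, n, h_v).
        b, s, n, h = q.shape
        s_kv = k.shape[1]
        if scale is None:
            scale = 1.0 / np.sqrt(h)
        # Move heads next to batch so the matmuls batch over (b, n).
        qt = q.transpose(0, 2, 1, 3)  # (b, n, s, h)
        kt = k.transpose(0, 2, 1, 3)  # (b, n, s_kv, h)
        vt = v.transpose(0, 2, 1, 3)  # (b, n, s_kv, h_v)
        score = (qt @ kt.transpose(0, 1, 3, 2)) * scale  # (b, n, s, s_kv)
        if bias is not None:
            score = score + bias  # bias broadcastable to (b, n, s, s_kv)
        if causal_mask is not None:
            # 'TopLeft' anchors the diagonal at k=0, 'BottomRight' at
            # k=abs(s - s_kv), matching the docstring's np.tril examples.
            k_off = 0 if causal_mask == "TopLeft" else abs(s - s_kv)
            keep = np.tril(np.ones((s, s_kv), dtype=bool), k=k_off)
            score = np.where(keep, score, -np.inf)
        # Numerically stable softmax over the key axis.
        score = score - score.max(axis=-1, keepdims=True)
        w = np.exp(score)
        w = w / w.sum(axis=-1, keepdims=True)
        return (w @ vt).transpose(0, 2, 1, 3)  # back to BSNH

    # Example: seq_len = 4, seq_len_kv = 2, matching the first mask example.
    q = np.random.rand(2, 4, 3, 8).astype("float32")
    k = np.random.rand(2, 2, 3, 8).astype("float32")
    v = np.random.rand(2, 2, 3, 8).astype("float32")
    out = attention_ref(q, k, v, causal_mask="TopLeft")  # shape (2, 4, 3, 8)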