Commit

refactor(ivy, torch-frontend): Extends ivy.multi_head_attention and shortens torch_frontend.multi_head_attention_forward (ivy-llc#23131)
AnnaTz authored and druvdub committed Oct 14, 2023
1 parent effbdb9 commit fdab979
Showing 7 changed files with 590 additions and 481 deletions.
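For context, the keyword arguments this commit adds to `ivy.multi_head_attention` (key_padding_mask, bias_k, bias_v, static_k, static_v, add_zero_attn) mirror the corresponding inputs of torch.nn.functional.multi_head_attention_forward. A minimal usage sketch of the extended call, assuming the torch backend is installed and using untrained random weights and arbitrary illustrative shapes:

```python
import ivy

ivy.set_backend("torch")  # assumption: the torch backend is available

batch, seq, embed_dim, num_heads = 2, 5, 16, 4
query = ivy.random_normal(shape=(batch, seq, embed_dim))               # batch_first=True by default
in_proj_weights = ivy.random_normal(shape=(3 * embed_dim, embed_dim))  # fused q/k/v projection
out_proj_weights = ivy.random_normal(shape=(embed_dim, embed_dim))
# boolean mask, True = ignore that key position (here: the last two positions)
key_padding_mask = ivy.array([[False, False, False, True, True]] * batch)

output, weights = ivy.multi_head_attention(
    query,                              # key/value default to query (self-attention)
    num_heads=num_heads,
    in_proj_weights=in_proj_weights,
    out_proj_weights=out_proj_weights,
    key_padding_mask=key_padding_mask,  # added by this commit
    add_zero_attn=False,                # added by this commit
    return_attention_weights=True,
)
print(output.shape)   # (2, 5, 16)
print(weights.shape)  # (2, 5, 5), since average_attention_weights=True by default
```

Under the torch backend a call like this is routed to torch's native multi_head_attention_forward (see the partial_mixed_handler in the backend diff below); other backends would go through ivy's compositional implementation.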
12 changes: 12 additions & 0 deletions ivy/data_classes/array/layers.py
@@ -409,6 +409,12 @@ def multi_head_attention(
in_proj_bias: Optional[Union[ivy.Array, ivy.NativeArray]] = None,
out_proj_bias: Optional[Union[ivy.Array, ivy.NativeArray]] = None,
is_causal: bool = False,
key_padding_mask: Optional[Union[ivy.Array, ivy.NativeArray]] = None,
bias_k: Optional[Union[ivy.Array, ivy.NativeArray]] = None,
bias_v: Optional[Union[ivy.Array, ivy.NativeArray]] = None,
static_k: Optional[Union[ivy.Array, ivy.NativeArray]] = None,
static_v: Optional[Union[ivy.Array, ivy.NativeArray]] = None,
add_zero_attn: bool = False,
return_attention_weights: bool = False,
average_attention_weights: bool = True,
dropout: float = 0.0,
@@ -430,6 +436,12 @@ def multi_head_attention(
in_proj_bias=in_proj_bias,
out_proj_bias=out_proj_bias,
is_causal=is_causal,
key_padding_mask=key_padding_mask,
bias_k=bias_k,
bias_v=bias_v,
static_k=static_k,
static_v=static_v,
add_zero_attn=add_zero_attn,
return_attention_weights=return_attention_weights,
average_attention_weights=average_attention_weights,
dropout=dropout,
28 changes: 28 additions & 0 deletions ivy/data_classes/container/layers.py
@@ -1055,6 +1055,14 @@ def _static_multi_head_attention(
Union[ivy.Array, ivy.NativeArray, ivy.Container]
] = None,
is_causal: Union[bool, ivy.Container] = False,
key_padding_mask: Optional[
Union[ivy.Array, ivy.NativeArray, ivy.Container]
] = None,
bias_k: Optional[Union[ivy.Array, ivy.NativeArray, ivy.Container]] = None,
bias_v: Optional[Union[ivy.Array, ivy.NativeArray, ivy.Container]] = None,
static_k: Optional[Union[ivy.Array, ivy.NativeArray, ivy.Container]] = None,
static_v: Optional[Union[ivy.Array, ivy.NativeArray, ivy.Container]] = None,
add_zero_attn: Union[bool, ivy.Container] = False,
return_attention_weights: Union[bool, ivy.Container] = False,
average_attention_weights: Union[bool, ivy.Container] = True,
dropout: Union[float, ivy.Container] = 0.0,
@@ -1081,6 +1089,12 @@ def _static_multi_head_attention(
in_proj_bias=in_proj_bias,
out_proj_bias=out_proj_bias,
is_causal=is_causal,
key_padding_mask=key_padding_mask,
bias_k=bias_k,
bias_v=bias_v,
static_k=static_k,
static_v=static_v,
add_zero_attn=add_zero_attn,
return_attention_weights=return_attention_weights,
average_attention_weights=average_attention_weights,
dropout=dropout,
@@ -1123,6 +1137,14 @@ def multi_head_attention(
Union[ivy.Array, ivy.NativeArray, ivy.Container]
] = None,
is_causal: Union[bool, ivy.Container] = False,
key_padding_mask: Optional[
Union[ivy.Array, ivy.NativeArray, ivy.Container]
] = None,
bias_k: Optional[Union[ivy.Array, ivy.NativeArray, ivy.Container]] = None,
bias_v: Optional[Union[ivy.Array, ivy.NativeArray, ivy.Container]] = None,
static_k: Optional[Union[ivy.Array, ivy.NativeArray, ivy.Container]] = None,
static_v: Optional[Union[ivy.Array, ivy.NativeArray, ivy.Container]] = None,
add_zero_attn: Union[bool, ivy.Container] = False,
return_attention_weights: Union[bool, ivy.Container] = False,
average_attention_weights: Union[bool, ivy.Container] = True,
dropout: Union[float, ivy.Container] = 0.0,
@@ -1148,6 +1170,12 @@ def multi_head_attention(
in_proj_bias=in_proj_bias,
out_proj_bias=out_proj_bias,
is_causal=is_causal,
key_padding_mask=key_padding_mask,
bias_k=bias_k,
bias_v=bias_v,
static_k=static_k,
static_v=static_v,
add_zero_attn=add_zero_attn,
return_attention_weights=return_attention_weights,
average_attention_weights=average_attention_weights,
dropout=dropout,
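The Array and Container methods above are pass-throughs, so the new keyword arguments are available there as well; a Container call maps over every leaf array, while non-container arguments (such as shared projection weights) are passed unchanged to each leaf. A small hedged sketch under the same assumptions as the first example:

```python
import ivy

ivy.set_backend("torch")  # assumption: the torch backend is available

E, H = 16, 4
W_in = ivy.random_normal(shape=(3 * E, E))  # shared fused in-projection
W_out = ivy.random_normal(shape=(E, E))     # shared out-projection

# two independent query batches held in one container
x = ivy.Container(
    a=ivy.random_normal(shape=(2, 5, E)),
    b=ivy.random_normal(shape=(3, 7, E)),
)

out = x.multi_head_attention(
    num_heads=H,
    in_proj_weights=W_in,
    out_proj_weights=W_out,
    add_zero_attn=True,  # one of the kwargs added by this commit
)
print(out.a.shape, out.b.shape)  # (2, 5, 16) (3, 7, 16)
```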
120 changes: 119 additions & 1 deletion ivy/functional/backends/torch/layers.py
@@ -6,11 +6,129 @@

# local
import ivy
from ivy.func_wrapper import with_unsupported_dtypes
from ivy.func_wrapper import with_unsupported_dtypes, with_supported_dtypes
from . import backend_version
from ivy.functional.ivy.layers import _handle_padding, _deconv_length


@with_supported_dtypes(
{"2.0.1 and below": ("float32", "float64", "complex")},
backend_version,
)
def multi_head_attention(
query: torch.Tensor,
/,
*,
key: torch.Tensor = None,
value: torch.Tensor = None,
batch_first: bool = True,
num_heads: Optional[int] = 8,
scale: Optional[float] = None,
attention_mask: torch.Tensor = None,
in_proj_weights: torch.Tensor = None,
q_proj_weights: torch.Tensor = None,
k_proj_weights: torch.Tensor = None,
v_proj_weights: torch.Tensor = None,
out_proj_weights: torch.Tensor = None,
in_proj_bias: torch.Tensor = None,
out_proj_bias: torch.Tensor = None,
is_causal: Optional[bool] = False,
key_padding_mask: Optional[torch.Tensor] = None,
bias_k: Optional[torch.Tensor] = None,
bias_v: Optional[torch.Tensor] = None,
static_k: Optional[torch.Tensor] = None,
static_v: Optional[torch.Tensor] = None,
add_zero_attn: bool = False,
return_attention_weights: Optional[bool] = False,
average_attention_weights: Optional[bool] = True,
dropout: Optional[float] = 0.0,
training: Optional[bool] = False,
out: torch.Tensor = None,
) -> torch.Tensor:
if key is None and value is None:
key = value = query
emb_dim = _get_embed_dim(
in_proj_weights,
q_proj_weights,
k_proj_weights,
v_proj_weights,
query,
)[1]
num_dims = query.ndim
if num_dims == 3 and batch_first:
query, key, value = [torch.swapaxes(x, 0, 1) for x in [query, key, value]]
ret = torch.nn.functional.multi_head_attention_forward(
query,
key,
value,
emb_dim,
num_heads,
in_proj_weights,
in_proj_bias,
bias_k,
bias_v,
add_zero_attn,
dropout,
out_proj_weights,
out_proj_bias,
training=training,
key_padding_mask=key_padding_mask,
need_weights=return_attention_weights,
attn_mask=attention_mask,
use_separate_proj_weight=not ivy.exists(in_proj_weights),
q_proj_weight=q_proj_weights,
k_proj_weight=k_proj_weights,
v_proj_weight=v_proj_weights,
static_k=static_k,
static_v=static_v,
average_attn_weights=average_attention_weights,
is_causal=is_causal,
)
ret = list(ret) if isinstance(ret, tuple) else [ret]
if num_dims == 3 and batch_first:
ret[0] = ret[0].swapaxes(0, 1)
if return_attention_weights:
return tuple(ret)
return ret[0]


multi_head_attention.partial_mixed_handler = (
lambda *args, scale=None, out_proj_weights=None, is_causal=False, attention_mask=None, return_attention_weights=False, in_proj_weights=None, q_proj_weights=None, k_proj_weights=None, v_proj_weights=None, **kwargs: not ivy.exists(
scale
)
and ivy.exists(out_proj_weights)
and (not is_causal or ivy.exists(attention_mask))
and (not is_causal or not return_attention_weights)
and (
ivy.exists(in_proj_weights)
or all(
[ivy.exists(x) for x in [q_proj_weights, k_proj_weights, v_proj_weights]]
)
)
and len(
set(
_get_embed_dim(
in_proj_weights, q_proj_weights, k_proj_weights, v_proj_weights, args[0]
)
)
)
== 1
)


def _get_embed_dim(
in_proj_weights, q_proj_weights, k_proj_weights, v_proj_weights, query
):
pre_embed_dim = query.shape[-1]
if ivy.exists(in_proj_weights):
embed_dim = in_proj_weights.shape[0] / 3
elif all([ivy.exists(x) for x in [q_proj_weights, k_proj_weights, v_proj_weights]]):
embed_dim = q_proj_weights.shape[0]
else:
embed_dim = None
return pre_embed_dim, embed_dim


@with_unsupported_dtypes(
{"2.0.1 and below": ("float16", "bfloat16", "complex")},
backend_version,
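The partial_mixed_handler above is the predicate that decides, per call, whether the arguments can be handed to torch.nn.functional.multi_head_attention_forward or whether ivy's compositional implementation has to be used instead. Spelled out as a hypothetical stand-alone function (a readability aid only, not part of the commit):

```python
def routes_to_torch_kernel(
    query_dim,
    scale=None,
    out_proj_weights=None,
    is_causal=False,
    attention_mask=None,
    return_attention_weights=False,
    in_proj_weights=None,
    q_proj_weights=None,
    k_proj_weights=None,
    v_proj_weights=None,
):
    """Readability-only mirror of multi_head_attention.partial_mixed_handler."""
    # embed dim implied by the projection weights, as in _get_embed_dim
    if in_proj_weights is not None:
        weight_dim = in_proj_weights.shape[0] / 3  # fused qkv weight: (3*E, E)
    elif all(w is not None for w in (q_proj_weights, k_proj_weights, v_proj_weights)):
        weight_dim = q_proj_weights.shape[0]       # separate weights: (E, E) each
    else:
        weight_dim = None
    return (
        scale is None                                        # torch applies the default 1/sqrt(head_dim) scaling
        and out_proj_weights is not None                     # torch always applies an output projection
        and (not is_causal or attention_mask is not None)    # the causal hint is only forwarded alongside a mask
        and (not is_causal or not return_attention_weights)  # ...and never together with returned weights
        and weight_dim is not None                           # fused or separate projection weights must be supplied
        and weight_dim == query_dim                          # and they must agree with the query's embed dim
    )
```

For example, the call in the first sketch satisfies every condition, so the torch backend dispatches it to the native kernel; passing a custom scale or omitting out_proj_weights would fall back to the compositional path.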
@@ -266,143 +266,36 @@ def multi_head_attention_forward(
average_attn_weights=True,
is_causal=False,
):
# q/k/v shape: (seq_len, batch_size, embed_dim)
seq_len, batch_size, embed_dim = query.shape
embed_dim = query.shape[-1]
assert (
embed_dim == embed_dim_to_check
), f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
assert key.shape == value.shape

head_dim = embed_dim // num_heads
assert head_dim * num_heads == embed_dim, "embed_dim needs to be divisible by heads"
scale = ivy.sqrt(head_dim)

if use_separate_proj_weight:
assert key.shape[:2] == value.shape[:2], (
f"key's sequence and batch dims {key.shape[:2]} do not match value's"
f" {value.shape[:2]}"
)
else:
assert (
key.shape == value.shape
), f"key shape {key.shape} does not match value shape {value.shape}"

if is_causal and key_padding_mask is None and not need_weights:
mask = ivy.tril(ivy.ones((seq_len, seq_len), dtype=query.dtype), k=0)
attn_mask = ivy.zeros((seq_len, seq_len), dtype=query.dtype)
attn_mask = ivy.where(mask == 0.0, float("-inf"), 0)

if in_proj_bias is None:
q_bias, k_bias, v_bias = None, None, None
else:
q_bias, k_bias, v_bias = ivy.split(in_proj_bias, num_or_size_splits=3)

if not use_separate_proj_weight:
q_proj_weight, k_proj_weight, v_proj_weight = ivy.split(
in_proj_weight, num_or_size_splits=3
)

q = ivy.linear(query, q_proj_weight, bias=q_bias)
k = ivy.linear(key, k_proj_weight, bias=k_bias)
v = ivy.linear(value, v_proj_weight, bias=v_bias)

if bias_k is not None and bias_v is not None:
assert static_k is None, "bias cannot be added to static key."
assert static_v is None, "bias cannot be added to static value."
k = ivy.concat([k, ivy.tile(bias_k, (1, batch_size, 1))])
v = ivy.concat([v, ivy.tile(bias_v, (1, batch_size, 1))])
if attn_mask is not None:
attn_mask = ivy.concat(
[attn_mask, ivy.zeros((attn_mask.shape[0], 1), dtype=attn_mask.dtype)],
axis=1,
)
if key_padding_mask is not None:
key_padding_mask = ivy.concat(
[
key_padding_mask,
ivy.zeros(
(key_padding_mask.shape[0], 1), dtype=key_padding_mask.dtype
).bool(),
],
axis=1,
)

q = ivy.swapaxes(q.reshape((q.shape[0], batch_size * num_heads, head_dim)), 0, 1)

if static_k is None:
k = ivy.swapaxes(
k.reshape((k.shape[0], batch_size * num_heads, head_dim)), 0, 1
)
else:
assert static_k.shape[0] == batch_size * num_heads, (
f"expecting static_k.shape[0] of {batch_size * num_heads}, but got"
f" {static_k.shape[0]}"
)
assert (
static_k.shape[2] == head_dim
), f"expecting static_k.shape[2] of {head_dim}, but got {static_k.shape[2]}"
k = static_k

if static_v is None:
v = ivy.swapaxes(
v.reshape((v.shape[0], batch_size * num_heads, head_dim)), 0, 1
)
else:
assert static_v.shape[0] == batch_size * num_heads, (
f"expecting static_v.shape[0] of {batch_size * num_heads}, but got"
f" {static_v.shape[0]}"
)
assert (
static_v.shape[2] == head_dim
), f"expecting static_v.shape[2] of {head_dim}, but got {static_v.shape[2]}"
v = static_v

# TODO add_zero_attn doesn't work for all cases
# fix this and add test cases (by changing to add_zero_attn=st.booleans())
if add_zero_attn:
zero_attn_shape = (batch_size * num_heads, 1, head_dim)
k = ivy.concat([k, ivy.zeros(zero_attn_shape, dtype=k.dtype)], axis=1)
v = ivy.concat([v, ivy.zeros(zero_attn_shape, dtype=v.dtype)], axis=1)
if attn_mask is not None:
attn_mask = ivy.pad(attn_mask, [(0, 0), (0, 1)])
if key_padding_mask is not None:
key_padding_mask = ivy.pad(key_padding_mask, [(0, 0), (0, 1)])

src_len = k.shape[1]
attn_weights = ivy.matmul(q, ivy.swapaxes(k, 1, 2))
assert list(attn_weights.shape) == [batch_size * num_heads, seq_len, src_len]

attn_weights = attn_weights / scale

if attn_mask is not None:
attn_mask = ivy.expand_dims(attn_mask, axis=0)
attn_weights += attn_mask

if key_padding_mask is not None:
key_padding_mask = ivy.expand_dims(
ivy.expand_dims(key_padding_mask, axis=1), axis=2
)
attn_weights = attn_weights.reshape((batch_size, num_heads, seq_len, src_len))
attn_weights = ivy.where(key_padding_mask < 0.0, float("-inf"), attn_weights)
attn_weights = attn_weights.reshape((batch_size * num_heads, seq_len, src_len))

attn_weights = ivy.softmax(attn_weights, axis=-1)
attn_weights = ivy.dropout(attn_weights, dropout_p, training=training)

attn_output = ivy.matmul(attn_weights, v)
assert list(attn_output.shape) == [batch_size * num_heads, seq_len, head_dim]
attn_output = ivy.swapaxes(attn_output, 0, 1).reshape(
(seq_len, batch_size, embed_dim)
return ivy.multi_head_attention(
query,
key=key,
value=value,
batch_first=False,
num_heads=num_heads,
attention_mask=attn_mask,
in_proj_weights=in_proj_weight if not use_separate_proj_weight else None,
q_proj_weights=q_proj_weight,
k_proj_weights=k_proj_weight,
v_proj_weights=v_proj_weight,
out_proj_weights=out_proj_weight,
in_proj_bias=in_proj_bias,
out_proj_bias=out_proj_bias,
is_causal=is_causal and not (need_weights or key_padding_mask is not None),
key_padding_mask=key_padding_mask,
bias_k=bias_k,
bias_v=bias_v,
static_k=static_k,
static_v=static_v,
add_zero_attn=add_zero_attn,
return_attention_weights=need_weights,
average_attention_weights=average_attn_weights,
dropout=dropout_p,
training=training,
)
attn_output = ivy.linear(attn_output, out_proj_weight, bias=out_proj_bias)

if need_weights:
attn_weights = attn_weights.reshape((batch_size, num_heads, seq_len, src_len))
if average_attn_weights:
attn_weights = ivy.sum(attn_weights, axis=1) / num_heads
return (attn_output, attn_weights)
else:
return (attn_output,)


@to_ivy_arrays_and_back
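With the body above, the frontend's multi_head_attention_forward becomes a thin wrapper: it forwards its torch-style arguments to ivy.multi_head_attention with batch_first=False, and it drops the is_causal hint whenever attention weights are requested or a key_padding_mask is given. A hedged call-level sketch, assuming the torch backend is set, with the argument order following torch's own multi_head_attention_forward (which the frontend mirrors) and arbitrary shapes:

```python
import ivy
import ivy.functional.frontends.torch as torch_frontend

ivy.set_backend("torch")  # assumption: the torch backend is available

L, N, E, H = 4, 2, 8, 2  # seq_len, batch, embed_dim, heads; batch_first=False here
q = ivy.random_normal(shape=(L, N, E))
in_proj_weight = ivy.random_normal(shape=(3 * E, E))
in_proj_bias = ivy.zeros((3 * E,))
out_proj_weight = ivy.random_normal(shape=(E, E))
out_proj_bias = ivy.zeros((E,))

attn_output, attn_weights = torch_frontend.nn.functional.multi_head_attention_forward(
    q, q, q,   # self-attention, shape (L, N, E)
    E,         # embed_dim_to_check
    H,         # num_heads
    in_proj_weight,
    in_proj_bias,
    None,      # bias_k
    None,      # bias_v
    False,     # add_zero_attn
    0.0,       # dropout_p
    out_proj_weight,
    out_proj_bias,
    training=False,
    need_weights=True,  # forwarded as return_attention_weights
)
print(attn_output.shape)   # (4, 2, 8)
print(attn_weights.shape)  # (2, 4, 4), with average_attn_weights=True
```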