Merged
4 changes: 2 additions & 2 deletions keras_nlp/layers/modeling/rotary_embedding.py
@@ -97,7 +97,7 @@ def _apply_rotary_pos_emb(self, tensor, cos_emb, sin_emb):
return (tensor * cos_emb) + (half_rot_tensor * sin_emb)

def _compute_cos_sin_embedding(self, x, rotary_dim, start_index):
- freq_range = ops.arange(0, rotary_dim, 2, dtype="float32")
+ freq_range = ops.arange(0, rotary_dim, 2)
freq_range = ops.cast(freq_range, self.compute_dtype)
freq_range = freq_range / ops.cast(
self.scaling_factor, self.compute_dtype
@@ -107,7 +107,7 @@ def _compute_cos_sin_embedding(self, x, rotary_dim, start_index):
** (freq_range / ops.cast(rotary_dim, self.compute_dtype))
)
seq_len = ops.shape(x)[self.sequence_axis]
- tensor = ops.arange(seq_len, dtype="float32") + start_index
+ tensor = ops.cast(ops.arange(seq_len), self.compute_dtype) + start_index
tensor = ops.cast(tensor, dtype=inverse_freq.dtype)
freq = ops.einsum("i, j -> ij", tensor, inverse_freq)
embedding = ops.concatenate((freq, freq), axis=self.feature_axis)
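The dtype changes above let RotaryEmbedding do its arithmetic in its compute_dtype and accept a tensor-valued start_index (the new Mistral attention layer below casts start_index to the rotary layer's compute_dtype before calling it). A minimal usage sketch, assuming only the public keras_nlp.layers.RotaryEmbedding API; the shapes are illustrative:

import numpy as np
from keras_nlp.layers import RotaryEmbedding

# (batch, sequence, num_heads, head_dim); the layer's defaults expect the
# sequence on axis 1 and the rotated features on the last axis.
x = np.random.uniform(size=(1, 8, 4, 16)).astype("float32")

rope = RotaryEmbedding(max_wavelength=10000, scaling_factor=1.0)
# start_index offsets the positions, e.g. when decoding token-by-token against
# a cache; it can be a Python int or, as in the attention layer below, a scalar
# tensor cast to the layer's compute_dtype.
out = rope(x, start_index=4)
print(out.shape)  # (1, 8, 4, 16)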
1 change: 1 addition & 0 deletions keras_nlp/models/__init__.py
@@ -89,6 +89,7 @@
GPTNeoXPreprocessor,
)
from keras_nlp.models.gpt_neo_x.gpt_neo_x_tokenizer import GPTNeoXTokenizer
+ from keras_nlp.models.mistral.mistral_backbone import MistralBackbone
from keras_nlp.models.opt.opt_backbone import OPTBackbone
from keras_nlp.models.opt.opt_causal_lm import OPTCausalLM
from keras_nlp.models.opt.opt_causal_lm_preprocessor import (
13 changes: 13 additions & 0 deletions keras_nlp/models/mistral/__init__.py
@@ -0,0 +1,13 @@
# Copyright 2023 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
293 changes: 293 additions & 0 deletions keras_nlp/models/mistral/mistral_attention.py
@@ -0,0 +1,293 @@
# Copyright 2023 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from keras_nlp.backend import keras
from keras_nlp.backend import ops
from keras_nlp.layers.modeling.rotary_embedding import RotaryEmbedding
from keras_nlp.utils.keras_utils import clone_initializer


# This is just a self-attention layer in Mistral. But it can be generalized
# to use the `keras_nlp.layers.CachedMultiHeadAttention` API. Since this layer
# implements grouped-query attention and sliding window attention, it might be
# useful outside of Mistral itself.
# TODO(tirthasheshpatel): Generalize the attention layer
# TODO(tirthasheshpatel): Merge `LlamaAttention` with this layer
# TODO(tirthasheshpatel): Use flash attention
class CachedMistralAttention(keras.layers.Layer):
"""A cached grounded query attention layer with sliding window."""

def __init__(
self,
num_query_heads,
num_key_value_heads,
rope_max_wavelength=10000,
rope_scaling_factor=1.0,
kernel_initializer="glorot_uniform",
sliding_window=512,
dropout=0,
**kwargs,
):
super().__init__(**kwargs)
self._num_query_heads = num_query_heads
self._num_key_value_heads = num_key_value_heads
self._sliding_window = sliding_window
self._dropout = dropout

self._num_key_value_groups = num_query_heads // num_key_value_heads
self._rope_max_wavelength = rope_max_wavelength

self._kernel_initializer = keras.initializers.get(
clone_initializer(kernel_initializer)
)

self._rope_scaling_factor = rope_scaling_factor

def build(self, inputs_shape):
# Einsum variables:
# b = batch size
# q = query length
# k = key/value length
# m = model dim
# u = num query heads
# v = num key/value heads
# h = head dim
self._hidden_dim = inputs_shape[-1]
self._head_dim = self._hidden_dim // self._num_query_heads

self._query_dense = keras.layers.EinsumDense(
equation="bqm,muh->bquh",
output_shape=(None, self._num_query_heads, self._head_dim),
kernel_initializer=self._kernel_initializer,
dtype=self.compute_dtype,
name="query",
)
self._query_dense.build(inputs_shape)

self._key_dense = keras.layers.EinsumDense(
equation="bkm,mvh->bkvh",
output_shape=(
None,
self._num_key_value_heads,
self._head_dim,
),
kernel_initializer=self._kernel_initializer,
dtype=self.compute_dtype,
name="key",
)
self._key_dense.build(inputs_shape)

self._value_dense = keras.layers.EinsumDense(
equation="bkm,mvh->bkvh",
output_shape=(
None,
self._num_key_value_heads,
self._head_dim,
),
kernel_initializer=self._kernel_initializer,
dtype=self.compute_dtype,
name="value",
)
self._value_dense.build(inputs_shape)

self._softmax = keras.layers.Softmax(axis=-1, name="attention_softmax")

self._dropout_layer = keras.layers.Dropout(
rate=self._dropout, dtype=self.compute_dtype
)

self._output_dense = keras.layers.EinsumDense(
equation="bquh,uhm->bqm",
output_shape=(None, self._hidden_dim),
kernel_initializer=self._kernel_initializer,
dtype=self.compute_dtype,
name="attention_output",
)
self._output_dense.build(
(None, None, self._num_query_heads, self._head_dim)
)

self.rotary_embedding_layer = RotaryEmbedding(
max_wavelength=self._rope_max_wavelength,
scaling_factor=self._rope_scaling_factor,
dtype=self.compute_dtype,
)

self._dot_product_equation = "bquh,bkuh->buqk"
self._combine_equation = "buqk,bkuh->bquh"

self.built = True

def call(
self,
hidden_states,
attention_mask=None,
cache=None,
cache_update_index=None,
Comment on lines +135 to +136
Contributor Author
Note: Right now, caching doesn't work when the sequence length is greater than the sliding window.

Can be addressed as a follow-up when adding the Generator model; shouldn't be a blocker here.

Member
Sounds good. If the upstream version is not solving this correctly, let's not worry too much about this.
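
For concreteness, a minimal plain-Python sketch of the rolling-buffer index arithmetic used in `call` below; the helper name and the toy window size are illustrative only, and the last case is the prompt-longer-than-window situation noted above:

def rolling_buffer_indices(cache_update_index, seq_len, sliding_window):
    # Mirrors the update_end_index / update_start_index computation in
    # CachedMistralAttention.call.
    update_end_index = (cache_update_index + seq_len - 1) % sliding_window + 1
    start = cache_update_index % sliding_window
    # If the write wraps past the end of the buffer, restart from slot 0.
    update_start_index = start if update_end_index > start else 0
    return update_start_index, update_end_index

print(rolling_buffer_indices(0, 3, sliding_window=4))  # (0, 3): prefill within the window
print(rolling_buffer_indices(3, 1, sliding_window=4))  # (3, 4): one-token decode step
print(rolling_buffer_indices(4, 1, sliding_window=4))  # (0, 1): decode step that wraps around
print(rolling_buffer_indices(0, 6, sliding_window=4))  # (0, 2): prompt longer than the window;
                                                       # the cache read only sees 2 slots, which
                                                       # is the limitation noted in this thread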

training=None,
):
seq_len = ops.shape(hidden_states)[1]
start_index = (
cache_update_index if cache_update_index is not None else 0
)
# If `cache_update_index` is a tensor, RotaryEmbedding expects it
# to have dtype `self.compute_dtype`.
start_index = ops.cast(
start_index, self.rotary_embedding_layer.compute_dtype
)

query = self._query_dense(hidden_states)

# Note that the original PyTorch implementation uses
# view_as_complex/view_as_real while we use split/concatenate to
Member
Can you explain this a bit more? Why do we need to consider complex numbers here?

Contributor Author
Here's the Mistral source for computing the frequencies (same as Llama 2) and for applying the embeddings (also the same as Llama 2).

The frequencies and inputs are treated as complex numbers, and the computation follows the "Theoretical Explanation" section of the paper.

PyTorch's view_as_complex is used to convert the tensors to complex numbers: it reshapes the inputs to shape (*x.shape[:-1], x.shape[-1] // 2, 2) and treats each pair of elements along axis -1 as a (real, imaginary) pair. RotaryEmbedding uses ops.split(x, 2) to convert the inputs to a complex representation (after splitting, the first half of the inputs becomes the real part and the other half becomes the imaginary part).

This is the only fundamental difference between the two computations. We can get the same results if we shuffle the inputs so that the alternating elements move to the end of the tensor. Hence the x = ops.concatenate([x[..., ::2], x[..., 1::2]], axis=-1) bit before passing the inputs to the rotary embedding layer.

The reverse transformation exactly mirrors/undoes what we did above.

Code demonstration of the above explanation
import torch
import numpy as np
from keras import ops

def _reshape_for_broadcast(freqs_cis, x):
    """
    freqs_cis: complex - (seq_len, head_dim / 2)
    x: complex - (bsz, seq_len, head_dim / 2)
    """
    ndim = x.ndim
    assert 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1]), (
        freqs_cis.shape,
        (x.shape[1], x.shape[-1]),
    )
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)


# Llama's version of rotary embeddings
def apply_rotary_emb(
    xq,
    freqs_cis,
):
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    freqs_cis = _reshape_for_broadcast(freqs_cis, xq_)
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq)


# Our version of the same computation.
# With transformations to match the `apply_rotary_emb` function above.
def apply_rotary_pos_emb(tensor, cos_emb, sin_emb):
    tensor = ops.concatenate((tensor[..., ::2], tensor[..., 1::2]), axis=-1)
    x1, x2 = ops.split(tensor, 2, axis=-1)
    half_rot_tensor = ops.concatenate((-x2, x1), axis=-1)
    res = (tensor * cos_emb) + (half_rot_tensor * sin_emb)
    return res

x = np.random.standard_normal((1,2,1,16))
cos_emb = np.random.standard_normal((2,8))
sin_emb = np.random.standard_normal((2,8))

print(x)

print(ops.concatenate((x[..., ::2], x[..., 1::2]), axis=-1))
print(np.split(np.concatenate((x[..., ::2], x[..., 1::2]), axis=-1), 2, axis=-1))

print(torch.view_as_complex(torch.tensor(x).reshape(*x.shape[:-1], -1, 2)))
print(apply_rotary_emb(torch.tensor(x), torch.tensor(cos_emb + sin_emb * 1.0j)))

y = apply_rotary_pos_emb(
    x,
    np.concatenate([cos_emb[None, :, None, :]]*2, axis=-1),
    np.concatenate([sin_emb[None, :, None, :]]*2, axis=-1)
)

print(ops.reshape(ops.stack(ops.split(y, 2, axis=-1), axis=-1), (y.shape[0], y.shape[1], y.shape[2], -1)))

A bit complicated, but it should be possible to achieve the same behavior by shuffling the weights with the same transformations. I believe that's what the Hugging Face folks have done, which is why this isn't required in the Llama backbone PR.

Member
Thanks!

> by shuffling the weights using the same transformations

Shuffling what weights?

At the highest level, we should just consider whether we should pull this into the lower level RotaryEmbedding layer. We want it to be useful for the most common use cases of rotary embeddings.
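
(A hedged sketch of the weight-shuffling idea discussed above, for illustration only; it is not part of this PR. The idea: permute the query/key projection kernels once at checkpoint-conversion time so the interleaved (real, imaginary) feature pairs already match RotaryEmbedding's split-half layout, making the per-call shuffles in _mistral_rope unnecessary. The (hidden_dim, num_heads, head_dim) kernel layout is assumed from the EinsumDense equations in this file.)

import numpy as np

def permute_rope_kernel(kernel):
    # kernel: assumed shape (hidden_dim, num_heads, head_dim), as produced by
    # the "bqm,muh->bquh" / "bkm,mvh->bkvh" EinsumDense layers above.
    # Even-indexed features (real parts) go to the first half of head_dim and
    # odd-indexed features (imaginary parts) to the second half, so projected
    # queries/keys come out directly in the split-half convention.
    return np.concatenate([kernel[..., 0::2], kernel[..., 1::2]], axis=-1)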

# convert to/from complex numbers. The transformations below make
# the rope computation numerically equivalent to the original
# implementation.
def _mistral_rope(x):
x = ops.concatenate([x[..., ::2], x[..., 1::2]], axis=-1)
x = self.rotary_embedding_layer(x, start_index=start_index)
x = ops.reshape(
ops.stack(ops.split(x, 2, axis=-1), axis=-1), ops.shape(x)
)
return x

# Compute RoPE for queries
query = _mistral_rope(query)

def _compute_key_value(x):
key, value = self._key_dense(x), self._value_dense(x)
key = _mistral_rope(key)
return key, value

if cache is not None:
cache_k = cache[:, 0, ...]
cache_v = cache[:, 1, ...]

if cache_update_index is not None:
# Compute the new keys and values
key, value = _compute_key_value(hidden_states)

# The cache is a rotating buffer; we want to wrap around if
# the sequence length exceeds the sliding window.
update_end_index = (
cache_update_index + seq_len - 1
) % self._sliding_window + 1
update_end_index = ops.cast(update_end_index, "int32")
Member
Just a general note: torch and jax like int32 on GPU, but tensorflow has limited op support for int32 (and does better with int64). We probably don't have coverage for this code path on TF with a GPU accelerator yet, but it might come up down the line.

cache_update_index = cache_update_index % self._sliding_window
update_start_index = ops.cond(
update_end_index > cache_update_index,
lambda: ops.cast(cache_update_index, "int32"),
lambda: ops.cast(0, "int32"),
)
# Also note that the update step below assumes that the
# sequence length is always one when `cache_update_index != 0`.
# This is necessary to support XLA compilation. Ideally, we
# would want to use
# `key[:, -(update_end_index - update_start_index):, ...]`
# as the update but updating using a dynamic slice gives an
# XLA compilation error in TensorFlow.
# Passing a sequence of length > 1 with cache update might give
# incorrect results (since there is no way to determine how
# many most recent tokens are to be saved if the tokens exceed
# the sliding window length).
cache_k = ops.slice_update(
cache_k,
[0, update_start_index, 0, 0],
# We slice the keys and values in case the user has passed
# a sequence of length > `self._sliding_window`; we want to
# prefill the cache using just the most recent values in the
# sliding window.
ops.cast(
key[:, -self._sliding_window :, ...], cache_k.dtype
),
)
cache_v = ops.slice_update(
cache_v,
[0, update_start_index, 0, 0],
ops.cast(
value[:, -self._sliding_window :, ...], cache_v.dtype
),
)
cache = ops.stack([cache_k, cache_v], axis=1)

# Get the required keys and values from the cache.
# Since we expect the user to pass a fixed-size cache, we just
# pick the first few slices up to and including the newly computed
# keys and values.
cache_k = cache_k[:, :update_end_index, ...]
cache_v = cache_v[:, :update_end_index, ...]
Comment on lines +227 to +228
Contributor Author
JAX fails here if cache_update_index is a traced JAX array. But the value of cache_update_index should be known at each step. I think the right fix here is to make sure that the GenerateTask model passes concrete values here. Otherwise, it would be pretty tricky to make sliding window attention work in JAX.

Member
Yeah, this looks unsupported by XLA today, as this would involve dynamic shapes in a compiled while_loop. Let's work on a fix as a follow up.
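
(A small illustration of the constraint described in this thread, using only public JAX APIs; the function names and shapes are made up. Slicing to a traced length fails under jit, while a static index works, which matches the suggestion to have the generation task pass concrete values.)

from functools import partial

import jax
import jax.numpy as jnp

@jax.jit
def read_cache_traced(cache, end_index):
    # Under jit every argument is traced, so the output shape of this slice
    # would depend on a runtime value; JAX rejects it instead of compiling a
    # dynamically shaped program.
    return cache[:, :end_index, ...]

@partial(jax.jit, static_argnums=1)
def read_cache_static(cache, end_index):
    # Marking the index static (a concrete Python int) makes the output shape
    # known at trace time, at the cost of one compilation per distinct value.
    return cache[:, :end_index, ...]

cache = jnp.zeros((1, 8, 4, 16))
try:
    read_cache_traced(cache, 3)
except Exception as err:
    print(type(err).__name__)  # a tracer/shape error
print(read_cache_static(cache, 3).shape)  # (1, 3, 4, 16)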


key = ops.cast(cache_k, dtype=self.compute_dtype)
value = ops.cast(cache_v, dtype=self.compute_dtype)
else:
# Compute keys and values
key, value = _compute_key_value(hidden_states)

# [batch_shape, seq_len, num_key_value_heads, head_dim]
# -> [batch_shape, seq_len, num_heads, head_dim]
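# Grouped-query attention: each key/value head is shared by
# `num_query_heads // num_key_value_heads` query heads, so the key and value
# tensors are tiled along the head axis before computing attention.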
key = ops.repeat(key, repeats=self._num_key_value_groups, axis=2)
value = ops.repeat(value, repeats=self._num_key_value_groups, axis=2)

attention_output = self._compute_attention(
query, key, value, attention_mask
)

attention_output = self._dropout_layer(
attention_output, training=training
)

attention_output = self._output_dense(attention_output)

if cache is not None:
return attention_output, cache
return attention_output

def _masked_softmax(self, attention_scores, attention_mask=None):
if attention_mask is not None:
return self._softmax(
attention_scores, attention_mask[:, None, :, :]
)
return self._softmax(attention_scores)

def _compute_attention(self, query, key, value, attention_mask=None):
attention_scores = ops.einsum(self._dot_product_equation, query, key)

norm_factor = ops.sqrt(ops.cast(self._head_dim, self.compute_dtype))

attention_scores = attention_scores / norm_factor

attention_scores = self._masked_softmax(
attention_scores, attention_mask
)
attention_output = ops.einsum(
self._combine_equation, attention_scores, value
)

return attention_output

def get_config(self):
config = super().get_config()
config.update(
{
"num_query_heads": self._num_query_heads,
"num_key_value_heads": self._num_key_value_heads,
"rope_max_wavelength": self._rope_max_wavelength,
"rope_scaling_factor": self._rope_scaling_factor,
"kernel_initializer": keras.initializers.serialize(
self._kernel_initializer
),
"sliding_window": self._sliding_window,
"dropout": self._dropout,
}
)
return config