
Commit 1951b5c

Add a Causal LM model for Mistral (#1429)
* Add Mistral Causal LM Preprocessor
* Add the Causal LM for Mistral
* Remove sliding window attention from Mistral's attention layer. JAX complains about dynamic slicing when compiled with XLA. This is unavoidable since, at runtime, the slice of the current key/value array to use for that iteration is determined by `cache_update_index`, which is itself a JAX `TracedArray`. Any workaround would lead to using dynamic shapes at some point. Hence, I had to remove this and instead use vanilla caching for now. For some reason, TensorFlow doesn't complain with XLA; I think this might be because TensorFlow is not as stringent about static shapes as JAX. In any case, adding sliding window attention that is XLA compatible is a story for the future. (See the sketch below.)
* Enable JIT compile in the Mistral LM model
* Fix Mistral transformer decoder
* Port the causal LM to the new infra
* Fix a minor bug in sliding window attention caching
* Fix a small bug in mistral transformer decoder
* Remove the RoPE shenanigan in mistral attention layer
* Address review comments and add mistral to the public API
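To make the sliding-window bullet concrete, here is a minimal sketch of the XLA constraint in plain JAX (illustrative function names, not KerasNLP code): slicing a cache to a length that depends on a traced value cannot be given a static shape, while a fixed-size update at a traced start index (the vanilla caching this commit falls back to) compiles fine.

import jax
import jax.numpy as jnp


@jax.jit
def sliced_cache(cache, length):
    # Fails when traced: the result's length depends on `length`, which is a
    # tracer under `jit`, so XLA cannot assign the output a static shape.
    return cache[:length]


@jax.jit
def updated_cache(cache, update, index):
    # Compiles: `update` has a static shape and only the start position is
    # dynamic, which XLA supports via `dynamic_update_slice`.
    return jax.lax.dynamic_update_slice(cache, update, (index,))


cache = jnp.zeros(8)
print(updated_cache(cache, jnp.ones(1), 3))  # [0. 0. 0. 1. 0. 0. 0. 0.]
# sliced_cache(cache, 3) raises a tracer/concretization error under jit.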
1 parent: 22c1e30

File tree

7 files changed (+640 additions, -80 deletions)


keras_nlp/models/__init__.py

Lines changed: 6 additions & 0 deletions
@@ -93,6 +93,12 @@
 from keras_nlp.models.gpt_neo_x.gpt_neo_x_tokenizer import GPTNeoXTokenizer
 from keras_nlp.models.llama.llama_backbone import LlamaBackbone
 from keras_nlp.models.mistral.mistral_backbone import MistralBackbone
+from keras_nlp.models.mistral.mistral_causal_lm import MistralCausalLM
+from keras_nlp.models.mistral.mistral_causal_lm_preprocessor import (
+    MistralCausalLMPreprocessor,
+)
+from keras_nlp.models.mistral.mistral_preprocessor import MistralPreprocessor
+from keras_nlp.models.mistral.mistral_tokenizer import MistralTokenizer
 from keras_nlp.models.opt.opt_backbone import OPTBackbone
 from keras_nlp.models.opt.opt_causal_lm import OPTCausalLM
 from keras_nlp.models.opt.opt_causal_lm_preprocessor import (
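With these imports in place, the new classes resolve from the public `keras_nlp.models` namespace. A quick smoke test (requires a KerasNLP build that contains this commit):

from keras_nlp.models import MistralCausalLM, MistralCausalLMPreprocessor

# Both symbols are now importable from the top-level models namespace.
print(MistralCausalLM.__name__, MistralCausalLMPreprocessor.__name__)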

keras_nlp/models/mistral/mistral_attention.py

Lines changed: 20 additions & 76 deletions
@@ -141,7 +141,6 @@ def call(
         cache_update_index=None,
         training=None,
     ):
-        seq_len = ops.shape(hidden_states)[1]
         start_index = (
             cache_update_index if cache_update_index is not None else 0
         )
@@ -153,89 +152,34 @@ def call(
 
         query = self._query_dense(hidden_states)
 
-        # Note that the original PyTorch implementation uses
-        # view_as_complex/view_as_real while we use split/concatenate to
-        # convert to/from complex numbers. The transformations below make
-        # the rope computation numerically equivalent to the original
-        # implementation.
-        def _mistral_rope(x):
-            x = ops.concatenate([x[..., ::2], x[..., 1::2]], axis=-1)
-            x = self.rotary_embedding_layer(x, start_index=start_index)
-            x = ops.reshape(
-                ops.stack(ops.split(x, 2, axis=-1), axis=-1), ops.shape(x)
-            )
-            return x
-
         # Compute RoPE for queries
-        query = _mistral_rope(query)
+        query = self.rotary_embedding_layer(query, start_index=start_index)
 
         def _compute_key_value(x):
             key, value = self._key_dense(x), self._value_dense(x)
-            key = _mistral_rope(key)
+            # Compute RoPE for keys
+            key = self.rotary_embedding_layer(key, start_index=start_index)
             return key, value
 
         if cache is not None:
-            cache_k = cache[:, 0, ...]
-            cache_v = cache[:, 1, ...]
-
+            key_cache = cache[:, 0, ...]
+            value_cache = cache[:, 1, ...]
+            if cache_update_index is None:
+                key = key_cache
+                value = value_cache
+            else:
+                key_update, value_update = _compute_key_value(hidden_states)
+                start = [0, cache_update_index, 0, 0]
+                key = ops.slice_update(key_cache, start, key_update)
+                value = ops.slice_update(value_cache, start, value_update)
+                cache = ops.stack((key, value), axis=1)
+        else:
             if cache_update_index is not None:
-                # Compute the new keys and values
-                key, value = _compute_key_value(hidden_states)
-
-                # Cache is a rotating buffer, we want to warp around if
-                # the sequence length exceeds the sliding window.
-                update_end_index = (
-                    cache_update_index + seq_len - 1
-                ) % self._sliding_window + 1
-                update_end_index = ops.cast(update_end_index, "int32")
-                cache_update_index = cache_update_index % self._sliding_window
-                update_start_index = ops.cond(
-                    update_end_index > cache_update_index,
-                    lambda: ops.cast(cache_update_index, "int32"),
-                    lambda: ops.cast(0, "int32"),
-                )
-                # Also note that the update step below assumes that the
-                # sequence length is always one when `cache_update_index != 0`.
-                # This is necessary to support XLA compilation. Ideally, we
-                # would want to use
-                # `key[:, -(update_end_index - update_start_index):, ...]`
-                # as the update but updating using a dynamic slice gives an
-                # XLA compilation error in TensorFlow.
-                # Passing a sequence of length > 1 with cache update might give
-                # incorrect results (since there is no way to determine how
-                # many most recent tokens are to be saved if the tokens exceed
-                # the sliding window length).
-                cache_k = ops.slice_update(
-                    cache_k,
-                    [0, update_start_index, 0, 0],
-                    # We slice the keys and values since if the user has passed
-                    # a sequence of length > `self._sliding_window`. We want to
-                    # prefill the cache using just the most recent values in the
-                    # sliding window.
-                    ops.cast(
-                        key[:, -self._sliding_window :, ...], cache_k.dtype
-                    ),
+                raise ValueError(
+                    "`cache_update_index` should not be set if `cache` is "
+                    f"`None`. Received: cache={cache}, "
+                    f"cache_update_index={cache_update_index}"
                 )
-                cache_v = ops.slice_update(
-                    cache_v,
-                    [0, update_start_index, 0, 0],
-                    ops.cast(
-                        value[:, -self._sliding_window :, ...], cache_v.dtype
-                    ),
-                )
-                cache = ops.stack([cache_k, cache_v], axis=1)
-
-                # Get the required keys and values from the cache.
-                # Since we expect the user to pass a fixed-size cache, we just
-                # pick the first few slices up-to and including the newly computed
-                # keys and values.
-                cache_k = cache_k[:, :update_end_index, ...]
-                cache_v = cache_v[:, :update_end_index, ...]
-
-                key = ops.cast(cache_k, dtype=self.compute_dtype)
-                value = ops.cast(cache_v, dtype=self.compute_dtype)
-        else:
-            # Compute keys and values
             key, value = _compute_key_value(hidden_states)
 
         # [batch_shape, seq_len, num_key_value_heads, head_dim]
@@ -265,7 +209,7 @@ def _masked_softmax(self, attention_scores, attention_mask=None):
         return self._softmax(attention_scores)
 
     def _compute_attention(self, query, key, value, attention_mask=None):
-        attention_scores = ops.einsum(self._dot_product_equation, key, query)
+        attention_scores = ops.einsum(self._dot_product_equation, query, key)
 
         norm_factor = ops.sqrt(ops.cast(self._head_dim, self.compute_dtype))
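The new branch above reduces decoding to a plain fixed-shape key/value cache: the projections of the current token are written in place at `cache_update_index`, and the cache shape never changes. Below is a standalone sketch of that pattern using Keras 3 `keras.ops` with toy shapes chosen purely for illustration; it is not the KerasNLP layer itself.

from keras import ops

batch, max_length, num_kv_heads, head_dim = 1, 8, 2, 4
# Fixed-shape cache: [batch, 2 (key/value), max_length, num_kv_heads, head_dim].
cache = ops.zeros((batch, 2, max_length, num_kv_heads, head_dim))

# Stand-ins for the key/value projections of a single new token.
key_update = ops.ones((batch, 1, num_kv_heads, head_dim))
value_update = ops.ones((batch, 1, num_kv_heads, head_dim))

cache_update_index = 3  # position of the new token in the sequence
start = [0, cache_update_index, 0, 0]
key = ops.slice_update(cache[:, 0, ...], start, key_update)
value = ops.slice_update(cache[:, 1, ...], start, value_update)
cache = ops.stack((key, value), axis=1)

# Only position 3 of the key cache is populated; the buffer size is unchanged.
print(ops.convert_to_numpy(ops.sum(cache[:, 0], axis=(0, 2, 3))))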
keras_nlp/models/mistral/mistral_causal_lm.py

Lines changed: 213 additions & 0 deletions
@@ -0,0 +1,213 @@
# Copyright 2023 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from keras_nlp.api_export import keras_nlp_export
from keras_nlp.backend import keras
from keras_nlp.backend import ops
from keras_nlp.models.generative_task import GenerativeTask
from keras_nlp.models.mistral.mistral_backbone import MistralBackbone
from keras_nlp.models.mistral.mistral_causal_lm_preprocessor import (
    MistralCausalLMPreprocessor,
)
from keras_nlp.utils.python_utils import classproperty


@keras_nlp_export("keras_nlp.models.MistralCausalLM")
class MistralCausalLM(GenerativeTask):
    """An end-to-end Mistral model for causal language modeling.

    A causal language model (LM) predicts the next token based on previous
    tokens. This task setup can be used to train the model unsupervised on
    plain text input, or to autoregressively generate plain text similar to
    the data used for training. This task can be used for pre-training or
    fine-tuning a Mistral model, simply by calling `fit()`.

    This model has a `generate()` method, which generates text based on a
    prompt. The generation strategy used is controlled by an additional
    `sampler` argument on `compile()`. You can recompile the model with
    different `keras_nlp.samplers` objects to control the generation. By
    default, `"top_k"` sampling will be used.

    Args:
        backbone: A `keras_nlp.models.MistralBackbone` instance.
        preprocessor: A `keras_nlp.models.MistralCausalLMPreprocessor` or
            `None`. If `None`, this model will not apply preprocessing, and
            inputs should be preprocessed before calling the model.
    """

    def __init__(self, backbone, preprocessor=None, **kwargs):
        # === Layers ===
        self.backbone = backbone
        self.preprocessor = preprocessor

        # === Functional Model ===
        inputs = backbone.inputs
        hidden_states = backbone(inputs)
        outputs = backbone.token_embedding(hidden_states, reverse=True)
        super().__init__(
            inputs=inputs,
            outputs=outputs,
            **kwargs,
        )

        # === Default compilation ===
        self.compile(
            loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
            optimizer=keras.optimizers.Adam(2e-5),
            metrics=[keras.metrics.SparseCategoricalAccuracy()],
            jit_compile=True,
        )

    @classproperty
    def backbone_cls(cls):
        return MistralBackbone

    @classproperty
    def preprocessor_cls(cls):
        return MistralCausalLMPreprocessor

    def call_with_cache(
        self,
        token_ids,
        cache,
        cache_update_index,
    ):
        """Forward pass of `MistralCausalLM` with cache.

        `call_with_cache` adds an additional forward pass for the model for
        autoregressive inference. Unlike calling the model directly, this
        method allows caching previous key/value Tensors in the multi-head
        attention layers, and avoids recomputing the outputs of seen tokens.

        Args:
            token_ids: a dense int Tensor with shape `(batch_size, max_length)`.
            cache: a dense float Tensor, the cache of key and value.
            cache_update_index: int, or int Tensor. The index of current inputs
                in the whole sequence.

        Returns:
            A (logits, hidden_states, cache) tuple. Where `logits` is the
            language model logits for the input token_ids, `hidden_states` is
            the final hidden representation of the input tokens, and `cache` is
            the decoding cache.
        """
        x = self.backbone.token_embedding(token_ids)
        # Each decoder layer has a cache; we update them separately.
        updated_cache = []
        for i in range(self.backbone.num_layers):
            current_cache = cache[:, i, ...]
            x, next_cache = self.backbone.transformer_layers[i](
                x,
                self_attention_cache=current_cache,
                self_attention_cache_update_index=cache_update_index,
            )
            updated_cache.append(next_cache)
        cache = ops.stack(updated_cache, axis=1)
        hidden_states = x = self.backbone.layer_norm(x)
        logits = self.backbone.token_embedding(x, reverse=True)
        return logits, hidden_states, cache

    def _build_cache(self, token_ids):
        """Build an empty cache for use with `call_with_cache()`."""
        batch_size = ops.shape(token_ids)[0]
        max_length = ops.shape(token_ids)[1]
        num_layers = self.backbone.num_layers
        num_key_value_heads = self.backbone.num_key_value_heads
        head_dim = self.backbone.hidden_dim // self.backbone.num_query_heads
        shape = [
            batch_size,
            num_layers,
            2,
            max_length,
            num_key_value_heads,
            head_dim,
        ]
        cache = ops.zeros(shape, dtype=self.compute_dtype)
        # Seed the cache.
        _, hidden_states, cache = self.call_with_cache(token_ids, cache, 0)
        return hidden_states, cache

    def generate_step(
        self,
        inputs,
        end_token_id=None,
    ):
        """A compilable generation function for a single batch of inputs.

        This function represents the inner, XLA-compilable, generation function
        for a single batch of inputs. Inputs should have the same structure as
        model inputs, a dictionary with keys `"token_ids"` and `"padding_mask"`.

        Args:
            inputs: A dictionary with two keys `"token_ids"` and
                `"padding_mask"` and batched tensor values.
            end_token_id: The id of the end token to stop on. If all
                sequences have produced a new `end_token_id`, generation
                will stop.
        """
        token_ids, padding_mask = inputs["token_ids"], inputs["padding_mask"]
        # Create and seed cache with a single forward pass.
        hidden_states, cache = self._build_cache(token_ids)
        # Compute the lengths of all user inputted token ids.
        row_lengths = ops.sum(ops.cast(padding_mask, "int32"), axis=-1)
        # Start at the first index that has no user inputted id.
        index = ops.min(row_lengths)

        def next(prompt, cache, index):
            # The cache index is the index of our previous token.
            cache_update_index = index - 1
            batch_size = ops.shape(prompt)[0]
            prompt = ops.slice(prompt, [0, cache_update_index], [batch_size, 1])
            logits, hidden_states, cache = self.call_with_cache(
                prompt,
                cache,
                cache_update_index,
            )
            return (
                ops.squeeze(logits, axis=1),
                ops.squeeze(hidden_states, axis=1),
                cache,
            )

        token_ids = self._sampler(
            next=next,
            prompt=token_ids,
            cache=cache,
            index=index,
            mask=padding_mask,
            end_token_id=end_token_id,
            hidden_states=hidden_states,
        )

        # Compute an output padding mask with the token ids we updated.
        if end_token_id is not None:
            # Build a mask of `end_token_id` locations not in the original
            # prompt (not in locations where `padding_mask` is True).
            end_locations = ops.logical_and(
                ops.equal(token_ids, end_token_id),
                ops.logical_not(padding_mask),
            )
            end_locations = ops.cast(end_locations, "int32")
            # Use cumsum to get ones in all locations after end_locations.
            cumsum = ops.cast(ops.cumsum(end_locations, axis=-1), "int32")
            overflow = cumsum - end_locations
            # Our padding mask is the inverse of these overflow locations.
            padding_mask = ops.logical_not(ops.cast(overflow, "bool"))
        else:
            # Without early stopping, all locations will have been updated.
            padding_mask = ops.ones_like(token_ids, dtype="bool")
        return {
            "token_ids": token_ids,
            "padding_mask": padding_mask,
        }
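To see the task model end to end, here is a hedged usage sketch: a randomly initialized toy backbone wrapped in `MistralCausalLM` with `preprocessor=None`, so `generate()` consumes and returns pre-tokenized dictionaries and no tokenizer assets are needed. The backbone hyperparameters below are placeholders for illustration, not a real Mistral configuration.

import numpy as np
import keras_nlp

# Tiny, randomly initialized backbone (placeholder sizes).
backbone = keras_nlp.models.MistralBackbone(
    vocabulary_size=100,
    num_layers=2,
    num_query_heads=8,
    num_key_value_heads=4,
    hidden_dim=16,
    intermediate_dim=32,
)
causal_lm = keras_nlp.models.MistralCausalLM(backbone=backbone, preprocessor=None)

# With `preprocessor=None`, `generate()` takes and returns dictionaries of
# `"token_ids"` and `"padding_mask"`; positions where the mask is 0 get filled.
prompt = {
    "token_ids": np.array([[2, 15, 7, 0, 0, 0, 0, 0]] * 2),
    "padding_mask": np.array([[1, 1, 1, 0, 0, 0, 0, 0]] * 2),
}
output = causal_lm.generate(prompt)
print(output["token_ids"])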
