Commit 32f166a

Add review comments
1 parent c5679d8 commit 32f166a

8 files changed: +295 -325 lines changed

keras_nlp/layers/modeling/rotary_embedding.py

Lines changed: 1 addition & 1 deletion
@@ -97,7 +97,7 @@ def _apply_rotary_pos_emb(self, tensor, cos_emb, sin_emb):
         return (tensor * cos_emb) + (half_rot_tensor * sin_emb)
 
     def _compute_cos_sin_embedding(self, x, rotary_dim, start_index):
-        freq_range = ops.cast(ops.arange(0, rotary_dim, 2), self.compute_dtype)
+        freq_range = ops.arange(0, rotary_dim, 2)
         freq_range = ops.cast(freq_range, self.compute_dtype)
         freq_range = freq_range / ops.cast(
             self.scaling_factor, self.compute_dtype
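
Review note: the change only pulls the dtype cast out of the `ops.arange` call; the computed frequencies are unchanged. A minimal NumPy sketch of the equivalence (the float32 compute dtype and unit scaling factor are assumptions for illustration):

import numpy as np

rotary_dim, scaling_factor = 8, 1.0
compute_dtype = "float32"

# old formulation: cast fused into the arange call
old = np.arange(0, rotary_dim, 2).astype(compute_dtype) / scaling_factor

# new formulation: integer range first, explicit cast afterwards
new = np.arange(0, rotary_dim, 2)
new = new.astype(compute_dtype)
new = new / np.asarray(scaling_factor, dtype=compute_dtype)

assert np.allclose(old, new)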

keras_nlp/models/mistral/mistral_attention.py

Lines changed: 5 additions & 3 deletions
@@ -29,7 +29,6 @@ class CachedMistralAttention(keras.layers.Layer):
 
     def __init__(
         self,
-        *,
         num_query_heads,
         num_key_value_heads,
         rope_max_wavelength=10000,
@@ -124,6 +123,9 @@ def build(self, inputs_shape):
             dtype=self.compute_dtype,
         )
 
+        self._dot_product_equation = "bquh,bkuh->buqk"
+        self._combine_equation = "buqk,bkuh->bquh"
+
         self.built = True
 
     def call(
@@ -258,7 +260,7 @@ def _masked_softmax(self, attention_scores, attention_mask=None):
         return self._softmax(attention_scores)
 
     def _compute_attention(self, query, key, value, attention_mask=None):
-        attention_scores = ops.einsum("aecd,abcd->acbe", key, query)
+        attention_scores = ops.einsum(self._dot_product_equation, key, query)
 
         norm_factor = ops.sqrt(ops.cast(self._head_dim, self.compute_dtype))
 
@@ -268,7 +270,7 @@ def _compute_attention(self, query, key, value, attention_mask=None):
             attention_scores, attention_mask
         )
         attention_output = ops.einsum(
-            "acbe,aecd->abcd", attention_scores, value
+            self._combine_equation, attention_scores, value
         )
 
         return attention_output
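
Review note: naming the einsum equations (`_dot_product_equation`, `_combine_equation`) reads much better than the previous opaque "aecd,abcd->acbe" strings. A quick NumPy shape check of the new equations; reading b as batch, u as heads, h as head size, and q/k as the two sequence axes is inferred from the letters, not stated in the diff:

import numpy as np

b, q, k, u, h = 2, 3, 5, 4, 8
x = np.zeros((b, q, u, h))  # operand indexed "bquh"
y = np.zeros((b, k, u, h))  # operand indexed "bkuh"

scores = np.einsum("bquh,bkuh->buqk", x, y)
print(scores.shape)  # (2, 4, 3, 5): batch, heads, q, k

out = np.einsum("buqk,bkuh->bquh", scores, y)
print(out.shape)     # (2, 3, 4, 8): batch, q, heads, head size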

keras_nlp/models/mistral/mistral_backbone.py

Lines changed: 1 addition & 1 deletion
@@ -64,6 +64,7 @@ class MistralBackbone(Backbone):
             layers in each transformer decoder. Only `sliding_window` number of tokens
             are saved in the cache and used to generate the next token.
             Defaults to `512`.
+        dtype (str, optional): The dtype policy for the mistral model.
 
     Examples:
 
@@ -95,7 +96,6 @@ class MistralBackbone(Backbone):
 
     def __init__(
         self,
-        *,
         vocabulary_size,
         num_layers,
         num_query_heads,
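
Review note: a hypothetical instantiation showing where the newly documented `dtype` argument fits; the toy sizes and any constructor parameters not visible in this diff (hidden_dim, intermediate_dim, num_key_value_heads) are assumptions, not the real 7B preset:

from keras_nlp.models import MistralBackbone

# Toy configuration for illustration only.
backbone = MistralBackbone(
    vocabulary_size=1000,
    num_layers=2,
    num_query_heads=8,
    num_key_value_heads=4,
    hidden_dim=64,
    intermediate_dim=128,
    sliding_window=512,
    dtype="bfloat16",  # the dtype policy documented in the new docstring line
)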

keras_nlp/models/mistral/mistral_layer_norm.py

Lines changed: 3 additions & 2 deletions
@@ -15,8 +15,9 @@
 from keras_nlp.backend import ops
 
 
-# TODO: Deprecate this in favor of `keras.layers.LayerNormalization` once
-# Keras 2 support is removed.
+# TODO: Deprecate this in favor of
+# `keras.layers.LayerNormalization(rms_scaling=True)` once Keras 2 support is
+# removed.
 class MistralLayerNormalization(keras.layers.Layer):
     """A normalization layer for Mistral that implements RMS normalization."""

keras_nlp/models/mistral/mistral_transformer_decoder.py

Lines changed: 0 additions & 1 deletion
@@ -31,7 +31,6 @@ class MistralTransformerDecoder(keras.layers.Layer):
 
     def __init__(
        self,
-        *,
        intermediate_dim,
        num_query_heads,
        num_key_value_heads,

tools/checkpoint_conversion/convert_mistral_checkpoints.py

Lines changed: 285 additions & 3 deletions
@@ -13,16 +13,298 @@
 # limitations under the License.
 import json
 import pathlib
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Optional
+from typing import Tuple
 
 import torch
+from torch import nn
 
 from keras_nlp.models import MistralBackbone
 
-from .scripts.mistral_torch import ModelArgs
-from .scripts.mistral_torch import Transformer as TorchTransformer
-
 MODEL_PATH = pathlib.Path("mistral-7B-v0.1")
 
+# Torch model taken from:
+# https://github.com/mistralai/mistral-src/blob/147c4e68279b90eb61b19bdea44e16f5539d5a5d/one_file_ref.py
+
+
+@dataclass
+class ModelArgs:
+    dim: int
+    n_layers: int
+    head_dim: int
+    hidden_dim: int
+    n_heads: int
+    n_kv_heads: int
+    sliding_window: int
+    norm_eps: float
+    vocab_size: int
+
+    max_batch_size: int = 0
+
+
+def repeat_kv(keys: torch.Tensor, values: torch.Tensor, repeats: int):
+    keys = torch.repeat_interleave(keys, repeats=repeats, dim=2)
+    values = torch.repeat_interleave(values, repeats=repeats, dim=2)
+    return keys, values
+
+
+def _reshape_for_broadcast(
+    freqs_cis: torch.Tensor, x: torch.Tensor
+) -> torch.Tensor:
+    """
+    freqs_cis: complex - (seq_len, head_dim / 2)
+    x: complex - (bsz, seq_len, head_dim / 2)
+    """
+    ndim = x.ndim
+    assert 1 < ndim
+    assert freqs_cis.shape == (x.shape[1], x.shape[-1]), (
+        freqs_cis.shape,
+        (x.shape[1], x.shape[-1]),
+    )
+    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
+    return freqs_cis.view(*shape)
+
+
+def apply_rotary_emb(
+    xq: torch.Tensor,
+    xk: torch.Tensor,
+    freqs_cis: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
+    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
+    freqs_cis = _reshape_for_broadcast(freqs_cis, xq_)
+    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
+    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
+    return xq_out.type_as(xq), xk_out.type_as(xk)
+
+
+class Attention(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+        self.args = args
+
+        self.n_heads: int = args.n_heads
+        self.n_kv_heads: int = args.n_kv_heads
+
+        self.repeats = self.n_heads // self.n_kv_heads
+        self.sliding_window = self.args.sliding_window
+
+        self.scale = self.args.head_dim**-0.5
+
+        self.wq = nn.Linear(args.dim, args.n_heads * args.head_dim, bias=False)
+        self.wk = nn.Linear(
+            args.dim, args.n_kv_heads * args.head_dim, bias=False
+        )
+        self.wv = nn.Linear(
+            args.dim, args.n_kv_heads * args.head_dim, bias=False
+        )
+        self.wo = nn.Linear(args.n_heads * args.head_dim, args.dim, bias=False)
+        self.cache_k = torch.empty(
+            (
+                args.max_batch_size,
+                args.sliding_window,
+                self.n_kv_heads,
+                self.args.head_dim,
+            ),
+            dtype=torch.float16,
+        )
+        self.cache_v = torch.empty(
+            (
+                args.max_batch_size,
+                args.sliding_window,
+                self.n_kv_heads,
+                self.args.head_dim,
+            ),
+            dtype=torch.float16,
+        )
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        freqs_cis: torch.Tensor,
+        positions: torch.Tensor,
+        mask: Optional[torch.Tensor],
+    ) -> torch.Tensor:
+        bsz, seqlen, _ = x.shape
+
+        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
+        xq = xq.view(bsz, seqlen, self.n_heads, self.args.head_dim)
+        xk = xk.view(bsz, seqlen, self.n_kv_heads, self.args.head_dim)
+        xv = xv.view(bsz, seqlen, self.n_kv_heads, self.args.head_dim)
+        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
+
+        # The cache is a rotating buffer
+        scatter_pos = (positions[-self.sliding_window :] % self.sliding_window)[
+            None, :, None, None
+        ]
+        scatter_pos = scatter_pos.repeat(
+            bsz, 1, self.n_kv_heads, self.args.head_dim
+        )
+        self.cache_k[:bsz].scatter_(
+            dim=1,
+            index=scatter_pos,
+            src=xk[:, -self.sliding_window :].to(self.cache_k.dtype),
+        )
+        self.cache_v[:bsz].scatter_(
+            dim=1,
+            index=scatter_pos,
+            src=xv[:, -self.sliding_window :].to(self.cache_v.dtype),
+        )
+
+        if positions.shape[0] > 1:
+            # prefill
+            key, value = repeat_kv(xk, xv, self.repeats)
+        else:
+            cur_pos = positions[-1].item() + 1
+            key, value = repeat_kv(
+                self.cache_k[:bsz, :cur_pos, ...].to(xk.dtype),
+                self.cache_v[:bsz, :cur_pos, ...].to(xv.dtype),
+                self.repeats,
+            )
+
+        query = xq.transpose(1, 2)
+        key = key.transpose(1, 2)
+        value = value.transpose(1, 2)
+        # scores : [bsz, n_heads, seqlen | 1, seqlen]
+        scores = torch.matmul(query, key.transpose(2, 3)) * self.scale
+
+        if mask is not None:
+            scores += mask[None, None, ...]
+
+        scores = scores.float()
+        scores = nn.functional.softmax(scores, dim=-1).type_as(query)
+        output = torch.matmul(
+            scores, value
+        )  # (bs, n_local_heads, slen, head_dim)
+        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
+        return self.wo(output)
+
+
+class FeedForward(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+
+        self.w1 = nn.Linear(args.dim, args.hidden_dim, bias=False)
+        self.w2 = nn.Linear(args.hidden_dim, args.dim, bias=False)
+        self.w3 = nn.Linear(args.dim, args.hidden_dim, bias=False)
+
+    def forward(self, x) -> torch.Tensor:
+        return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x))
+
+
+class RMSNorm(torch.nn.Module):
+    def __init__(self, dim: int, eps: float = 1e-6):
+        super().__init__()
+        self.eps = eps
+        self.weight = nn.Parameter(torch.ones(dim))
+
+    def _norm(self, x):
+        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+    def forward(self, x):
+        output = self._norm(x.float()).type_as(x)
+        return output * self.weight
+
+
+class TransformerBlock(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+        self.n_heads = args.n_heads
+        self.dim = args.dim
+        self.attention = Attention(args)
+        self.feed_forward = FeedForward(args=args)
+        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
+        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
+        self.args = args
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        freqs_cis: torch.Tensor,
+        positions: torch.Tensor,
+        mask: Optional[torch.Tensor],
+    ) -> torch.Tensor:
+        r = self.attention.forward(
+            self.attention_norm(x), freqs_cis, positions, mask
+        )
+        h = x + r
+        r = self.feed_forward.forward(self.ffn_norm(h))
+        out = h + r
+        return out
+
+
+def precompute_freqs_cis(
+    dim: int, end: int, theta: float = 10000.0
+) -> torch.Tensor:
+    freqs = 1.0 / (
+        theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
+    )
+    t = torch.arange(end, device=freqs.device)  # type: ignore
+    freqs = torch.outer(t, freqs).float()  # type: ignore
+    return torch.polar(torch.ones_like(freqs), freqs)  # complex64
+
+
+class TorchTransformer(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+        self.args = args
+        self.vocab_size = args.vocab_size
+        self.n_layers = args.n_layers
+        assert self.vocab_size > 0
+
+        self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)
+
+        self.layers = torch.nn.ModuleList(
+            [TransformerBlock(args=args) for _ in range(args.n_layers)]
+        )
+
+        self.norm = RMSNorm(args.dim, eps=args.norm_eps)
+
+        self.output = nn.Linear(args.dim, args.vocab_size, bias=False)
+
+        self.freqs_cis = precompute_freqs_cis(self.args.head_dim, 128_000)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+    ):
+        h = self.tok_embeddings(input_ids)
+        freqs_cis = self.freqs_cis[positions]
+
+        mask: Optional[torch.Tensor] = None
+        if input_ids.shape[1] > 1:
+            seqlen = input_ids.shape[1]
+            tensor = torch.full(
+                (seqlen, seqlen),
+                dtype=h.dtype,
+                fill_value=1,
+                device=h.device,
+            )
+            mask = torch.tril(tensor, diagonal=0).to(h.dtype)
+            # make the mask banded to account for sliding window
+            mask = torch.triu(mask, diagonal=-self.args.sliding_window)
+            mask = torch.log(mask)
+
+        for layer in self.layers:
+            h = layer(h, freqs_cis, positions, mask)
+
+        return self.output(self.norm(h)).float()
+
+    @staticmethod
+    def from_folder(
+        folder: Path, max_batch_size: int = 1, device="cpu", dtype=torch.float16
+    ):
+        with open(folder / "params.json", "r") as f:
+            model_args = ModelArgs(**json.loads(f.read()))
+        model_args.max_batch_size = max_batch_size
+        model = TorchTransformer(model_args).to(device=device, dtype=dtype)
+        loaded = torch.load(folder / "consolidated.00.pth")
+        model.load_state_dict(loaded)
+        return model
+
 
 def port_weights(
     model_k3: MistralBackbone, model_torch: TorchTransformer, params: ModelArgs
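
Review note: inlining the reference Torch model makes the conversion script self-contained, which is why the `scripts` package below is deleted. The sliding-window mask built in `TorchTransformer.forward` is the subtlest part; a standalone sketch of that construction (window size and sequence length are arbitrary here):

import torch

seqlen, sliding_window = 6, 3
ones = torch.full((seqlen, seqlen), fill_value=1.0)
mask = torch.tril(ones, diagonal=0)                # causal: keep j <= i
mask = torch.triu(mask, diagonal=-sliding_window)  # banded: keep i - j <= sliding_window
mask = torch.log(mask)                             # 1 -> 0.0 (attend), 0 -> -inf (masked)
print(mask)  # additive attention bias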

tools/checkpoint_conversion/scripts/__init__.py

Lines changed: 0 additions & 13 deletions
This file was deleted.
