
Commit ad2ae64

Add docs; Make args keyword-only; Cosmetic fixes
1 parent 45b03a5 commit ad2ae64

7 files changed: +553 −2 lines changed

keras_nlp/models/mistral/mistral_attention.py

Lines changed: 3 additions & 1 deletion
@@ -24,10 +24,12 @@
 # TODO(tirthasheshpatel): Generalize the attention layer
 # TODO(tirthasheshpatel): Merge `LlamaAttention` with this layer
 # TODO(tirthasheshpatel): Use flash attention
-# TODO(tirthasheshpatel): Add dropout
 class CachedMistralAttention(keras.layers.Layer):
+    """A cached grouped query attention layer with sliding window."""
+
     def __init__(
         self,
+        *,
         num_query_heads,
         num_key_value_heads,
         rope_max_wavelength=10000,
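
With the bare `*` added to `__init__`, every constructor argument after `self` becomes keyword-only, matching the commit message. A minimal sketch of the effect on callers (the argument values are illustrative, and it assumes the arguments not shown in this hunk keep their defaults):

```python
from keras_nlp.models.mistral.mistral_attention import CachedMistralAttention

# Arguments must now be passed by name.
attention = CachedMistralAttention(
    num_query_heads=32,
    num_key_value_heads=8,
)

# A positional call such as `CachedMistralAttention(32, 8)` now raises a
# TypeError instead of silently binding arguments by position.
```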

keras_nlp/models/mistral/mistral_backbone.py

Lines changed: 68 additions & 0 deletions
@@ -30,8 +30,72 @@ def _mistral_kernel_initializer(stddev=0.02):
 
 @keras_nlp_export("keras_nlp.models.MistralBackbone")
 class MistralBackbone(Backbone):
+    """
+    The Mistral Transformer core architecture with hyperparameters.
+
+    This network implements a Transformer-based decoder network,
+    Mistral, as described in
+    ["Mistral 7B"](https://arxiv.org/pdf/2310.06825.pdf).
+    It includes the embedding lookups and transformer layers.
+
+    The default constructor gives a fully customizable, randomly initialized
+    Mistral model with any number of layers, heads, and embedding
+    dimensions. To load preset architectures and weights, use the `from_preset`
+    constructor.
+
+    Args:
+        vocabulary_size (int): The size of the token vocabulary.
+        num_layers (int): The number of transformer layers.
+        num_query_heads (int): The number of query attention heads for
+            each transformer.
+        hidden_dim (int): The size of the transformer encoding and pooling layers.
+        intermediate_dim (int): The output dimension of the first Dense layer in a
+            three-layer feedforward network for each transformer.
+        num_key_value_heads (int): The number of key and value attention heads for
+            each transformer.
+        rope_max_wavelength (int, optional): The maximum angular wavelength of the
+            sine/cosine curves, for rotary embeddings. Defaults to `10000`.
+        rope_scaling_factor (float, optional): The scaling factor for calculation
+            of rotary embeddings. Defaults to `1.0`.
+        layer_norm_epsilon (float, optional): Epsilon for the layer normalization
+            layers in the transformer decoder. Defaults to `1e-6`.
+        sliding_window (int, optional): The sliding window for the Mistral
+            attention layers. This controls the maximum cache size for the attention
+            layers in each transformer decoder. Only `sliding_window` number of tokens
+            are saved in the cache and used to generate the next token.
+            Defaults to `512`.
+
+    Examples:
+
+    ```python
+    input_data = {
+        "token_ids": np.ones(shape=(1, 12), dtype="int32"),
+        "padding_mask": np.array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]),
+    }
+
+    # Pretrained Mistral decoder.
+    model = keras_nlp.models.MistralBackbone.from_preset("mistral7b_base_en")
+    model(input_data)
+
+    # Randomly initialized Mistral decoder with custom config.
+    model = keras_nlp.models.MistralBackbone(
+        vocabulary_size=10,
+        hidden_dim=512,
+        num_layers=2,
+        num_query_heads=32,
+        num_key_value_heads=8,
+        intermediate_dim=1024,
+        sliding_window=512,
+        layer_norm_epsilon=1e-6,
+        dtype="float32"
+    )
+    model(input_data)
+    ```
+    """
+
     def __init__(
         self,
+        *,
         vocabulary_size,
         num_layers,
         num_query_heads,
@@ -42,6 +106,7 @@ def __init__(
         rope_scaling_factor=1.0,
         layer_norm_epsilon=1e-6,
         sliding_window=512,
+        dropout=0,
         **kwargs,
     ):
         # Get the dtype
@@ -76,6 +141,7 @@ def __init__(
                 activation=ops.silu,
                 kernel_initializer=_mistral_kernel_initializer(stddev=0.02),
                 sliding_window=sliding_window,
+                dropout=dropout,
                 dtype=dtype,
                 name=f"transformer_layer_{i}",
             )(x, decoder_padding_mask=padding_mask)
@@ -107,6 +173,7 @@ def __init__(
         self.rope_scaling_factor = rope_scaling_factor
         self.sliding_window = sliding_window
         self.layer_norm_epsilon = layer_norm_epsilon
+        self.dropout = dropout
         self.token_embedding = token_embedding_layer
 
     def get_config(self):
@@ -123,6 +190,7 @@ def get_config(self):
                 "num_key_value_heads": self.num_key_value_heads,
                 "sliding_window": self.sliding_window,
                 "layer_norm_epsilon": self.layer_norm_epsilon,
+                "dropout": self.dropout,
             }
         )
         return config
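
Because `dropout` is now both a constructor argument and an entry in `get_config`, it survives serialization. A small sketch of the round trip, assuming the standard `Backbone.from_config(config)` behavior applies here (the layer sizes are illustrative, not a preset):

```python
import keras_nlp

# A tiny randomly initialized backbone using the new `dropout` argument.
model = keras_nlp.models.MistralBackbone(
    vocabulary_size=10,
    num_layers=2,
    num_query_heads=4,
    num_key_value_heads=2,
    hidden_dim=64,
    intermediate_dim=128,
    dropout=0.1,
)

# `dropout` is serialized, so restoring from the config preserves it.
config = model.get_config()
assert config["dropout"] == 0.1
restored = keras_nlp.models.MistralBackbone.from_config(config)
```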

keras_nlp/models/mistral/mistral_layer_norm.py

Lines changed: 4 additions & 0 deletions
@@ -15,7 +15,11 @@
 from keras_nlp.backend import ops
 
 
+# TODO: Deprecate this in favor of `keras.layers.LayerNormalization` once
+# Keras 2 support is removed.
 class MistralLayerNormalization(keras.layers.Layer):
+    """A normalization layer for Mistral that implements RMS normalization."""
+
     def __init__(self, epsilon=1e-6, **kwargs):
         super().__init__(**kwargs)
         self._epsilon = epsilon
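
For context on the new docstring: RMS normalization rescales each feature vector by the reciprocal of its root mean square and applies a learned per-feature scale, with no mean subtraction or bias. A minimal NumPy sketch of that computation (not the layer's actual implementation, which uses `keras_nlp.backend.ops`):

```python
import numpy as np


def rms_norm(x, weight, epsilon=1e-6):
    # Mean of squares over the feature axis, kept for broadcasting.
    mean_square = np.mean(np.square(x), axis=-1, keepdims=True)
    # Scale by 1/RMS, then by the learned per-feature weight.
    return x / np.sqrt(mean_square + epsilon) * weight


x = np.random.randn(2, 8).astype("float32")
weight = np.ones(8, dtype="float32")
print(rms_norm(x, weight).shape)  # (2, 8)
```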

keras_nlp/models/mistral/mistral_transformer_decoder.py

Lines changed: 3 additions & 1 deletion
@@ -26,10 +26,12 @@
 from keras_nlp.utils.keras_utils import clone_initializer
 
 
-# TODO(tirthasheshpatel): Add dropout
 class MistralTransformerDecoder(keras.layers.Layer):
+    """A Transformer decoder layer for the Mistral backbone."""
+
     def __init__(
         self,
+        *,
         intermediate_dim,
         num_query_heads,
         num_key_value_heads,
Lines changed: 161 additions & 0 deletions
@@ -0,0 +1,161 @@
+# Copyright 2023 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import json
+import pathlib
+
+import torch
+
+from keras_nlp.models import MistralBackbone
+
+from .scripts.mistral_torch import ModelArgs
+from .scripts.mistral_torch import Transformer as TorchTransformer
+
+MODEL_PATH = pathlib.Path("mistral-7B-v0.1")
+
+
+def port_weights(
+    model_k3: MistralBackbone, model_torch: TorchTransformer, params: ModelArgs
+):
+    model_k3.get_layer("token_embedding").embeddings.assign(
+        model_torch.tok_embeddings.weight.detach().cpu().numpy()
+    )
+
+    for i in range(model_k3.num_layers):
+        model_k3.get_layer(
+            f"transformer_layer_{i}"
+        )._self_attention_layer._key_dense.set_weights(
+            [
+                model_torch.layers[i]
+                .attention.wk.weight.T.reshape(
+                    params.dim, params.n_kv_heads, params.head_dim
+                )
+                .detach()
+                .cpu()
+                .numpy()
+            ]
+        )
+        model_k3.get_layer(
+            f"transformer_layer_{i}"
+        )._self_attention_layer._query_dense.set_weights(
+            [
+                model_torch.layers[i]
+                .attention.wq.weight.T.reshape(
+                    params.dim, params.n_heads, params.head_dim
+                )
+                .detach()
+                .cpu()
+                .numpy()
+            ]
+        )
+        model_k3.get_layer(
+            f"transformer_layer_{i}"
+        )._self_attention_layer._value_dense.set_weights(
+            [
+                model_torch.layers[i]
+                .attention.wv.weight.T.reshape(
+                    params.dim, params.n_kv_heads, params.head_dim
+                )
+                .detach()
+                .cpu()
+                .numpy()
+            ]
+        )
+        model_k3.get_layer(
+            f"transformer_layer_{i}"
+        )._self_attention_layer._output_dense.set_weights(
+            [
+                model_torch.layers[i]
+                .attention.wo.weight.T.reshape(
+                    params.n_heads, params.head_dim, params.dim
+                )
+                .detach()
+                .cpu()
+                .numpy()
+            ]
+        )
+        model_k3.get_layer(
+            f"transformer_layer_{i}"
+        )._self_attention_layernorm.set_weights(
+            [model_torch.layers[i].attention_norm.weight.detach().cpu().numpy()]
+        )
+        model_k3.get_layer(
+            f"transformer_layer_{i}"
+        )._feedforward_intermediate_dense.set_weights(
+            [
+                model_torch.layers[i]
+                .feed_forward.w3.weight.T.detach()
+                .cpu()
+                .numpy()
+            ]
+        )
+        model_k3.get_layer(
+            f"transformer_layer_{i}"
+        )._feedforward_output_dense.set_weights(
+            [
+                model_torch.layers[i]
+                .feed_forward.w2.weight.T.detach()
+                .cpu()
+                .numpy()
+            ]
+        )
+        model_k3.get_layer(
+            f"transformer_layer_{i}"
+        )._feedforward_gate_dense.set_weights(
+            [
+                model_torch.layers[i]
+                .feed_forward.w1.weight.T.detach()
+                .cpu()
+                .numpy()
+            ]
+        )
+        model_k3.get_layer(
+            f"transformer_layer_{i}"
+        )._feedforward_layernorm.set_weights(
+            [model_torch.layers[i].ffn_norm.weight.detach().cpu().numpy()]
+        )
+
+    model_k3.get_layer("sequence_output_layernorm").set_weights(
+        [model_torch.norm.weight.detach().cpu().numpy()]
+    )
+    model_k3.get_layer("token_embedding").reverse_embeddings.assign(
+        model_torch.output.weight.T.detach().cpu().numpy()
+    )
+
+
+if __name__ == "__main__":
+    with open(MODEL_PATH / "params.json", "r") as params_file:
+        params = ModelArgs(**json.load(params_file))
+
+    model_torch = TorchTransformer.from_folder(
+        MODEL_PATH, device="cpu", dtype=torch.float16
+    )
+    print("Torch model loaded")
+    model_k3 = MistralBackbone(
+        vocabulary_size=32000,
+        hidden_dim=4096,
+        num_layers=32,
+        num_query_heads=32,
+        num_key_value_heads=8,
+        intermediate_dim=14336,
+        sliding_window=4096,
+        layer_norm_epsilon=1e-6,
+        dtype="float16",
+    )
+    print("Keras 3 model loaded.")
+
+    port_weights(model_k3, model_torch, params)
+    print("Weight transfer done.")
+
+    model_k3.save_weights("mistral_7b.weights.h5")
+    print("Weights saved.")
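
A note on the `.T.reshape(...)` pattern above: a torch `nn.Linear` stores its weight as `(out_features, in_features)`, while the Keras attention projections in this script receive a `(hidden_dim, num_heads, head_dim)` kernel via `set_weights`, so each weight is transposed and its output axis split per head. A small NumPy sketch of that shape mapping for the query projection, with Mistral-7B sizes used purely as an illustration:

```python
import numpy as np

dim, n_heads, head_dim = 4096, 32, 128  # Mistral-7B query projection sizes

# A torch `nn.Linear(dim, n_heads * head_dim)` stores its weight as
# (out_features, in_features).
torch_wq = np.zeros((n_heads * head_dim, dim), dtype="float16")

# Transpose to (in_features, out_features), then split the output axis into
# (num_heads, head_dim) to match the kernel layout used by the script above.
keras_kernel = torch_wq.T.reshape(dim, n_heads, head_dim)
print(keras_kernel.shape)  # (4096, 32, 128)
```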
Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
+# Copyright 2023 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
