
Commit 2b8c099

Add ALBERT Conversion Script (#736)
* Add ALBERT Conversion Script
* Minor typo fix
* Reformat
* Fixes
* Small print statement change
1 parent 10d451e commit 2b8c099

1 file changed: 319 additions (+), 0 deletions (-)
@@ -0,0 +1,319 @@
# Copyright 2023 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import shutil

import numpy as np
import tensorflow as tf
import transformers
from absl import app
from absl import flags

import keras_nlp
from tools.checkpoint_conversion.checkpoint_conversion_utils import (
    get_md5_checksum,
)

PRESET_MAP = {
    "albert_base_en_uncased": "albert-base-v2",
    "albert_large_en_uncased": "albert-large-v2",
    "albert_extra_large_en_uncased": "albert-xlarge-v2",
    "albert_extra_extra_large_en_uncased": "albert-xxlarge-v2",
}


FLAGS = flags.FLAGS
flags.DEFINE_string(
    "preset", None, f'Must be one of {",".join(PRESET_MAP.keys())}'
)


def convert_checkpoints(hf_model):
    print("\n-> Convert original weights to KerasNLP format.")

    print("\n-> Load KerasNLP model.")
    keras_nlp_model = keras_nlp.models.AlbertBackbone.from_preset(
        FLAGS.preset, load_weights=False
    )

    hf_wts = hf_model.state_dict()
    print("Original weights:")
    print(list(hf_wts.keys()))

    num_heads = keras_nlp_model.num_heads
    hidden_dim = keras_nlp_model.hidden_dim

    keras_nlp_model.get_layer("token_embedding").embeddings.assign(
        hf_wts["embeddings.word_embeddings.weight"]
    )
    keras_nlp_model.get_layer("position_embedding").position_embeddings.assign(
        hf_wts["embeddings.position_embeddings.weight"]
    )
    keras_nlp_model.get_layer("segment_embedding").embeddings.assign(
        hf_wts["embeddings.token_type_embeddings.weight"]
    )

    keras_nlp_model.get_layer("embeddings_layer_norm").gamma.assign(
        hf_wts["embeddings.LayerNorm.weight"]
    )
    keras_nlp_model.get_layer("embeddings_layer_norm").beta.assign(
        hf_wts["embeddings.LayerNorm.bias"]
    )

    keras_nlp_model.get_layer("embedding_projection").kernel.assign(
        hf_wts["encoder.embedding_hidden_mapping_in.weight"].T
    )
    keras_nlp_model.get_layer("embedding_projection").bias.assign(
        hf_wts["encoder.embedding_hidden_mapping_in.bias"]
    )

    # HF stores dense kernels as (out_features, in_features); transpose them to
    # the (in_features, out_features) layout Keras expects, and reshape the
    # attention projections into per-head shapes before assigning.
    for i in range(keras_nlp_model.num_groups):
        for j in range(keras_nlp_model.num_inner_repetitions):
            keras_nlp_model.get_layer(
                f"group_{i}_inner_layer_{j}"
            )._self_attention_layer._query_dense.kernel.assign(
                hf_wts[
                    f"encoder.albert_layer_groups.{i}.albert_layers.{j}.attention.query.weight"
                ]
                .transpose(1, 0)
                .reshape((hidden_dim, num_heads, -1))
                .numpy()
            )
            keras_nlp_model.get_layer(
                f"group_{i}_inner_layer_{j}"
            )._self_attention_layer._query_dense.bias.assign(
                hf_wts[
                    f"encoder.albert_layer_groups.{i}.albert_layers.{j}.attention.query.bias"
                ]
                .reshape((num_heads, -1))
                .numpy()
            )

            keras_nlp_model.get_layer(
                f"group_{i}_inner_layer_{j}"
            )._self_attention_layer._key_dense.kernel.assign(
                hf_wts[
                    f"encoder.albert_layer_groups.{i}.albert_layers.{j}.attention.key.weight"
                ]
                .transpose(1, 0)
                .reshape((hidden_dim, num_heads, -1))
                .numpy()
            )
            keras_nlp_model.get_layer(
                f"group_{i}_inner_layer_{j}"
            )._self_attention_layer._key_dense.bias.assign(
                hf_wts[
                    f"encoder.albert_layer_groups.{i}.albert_layers.{j}.attention.key.bias"
                ]
                .reshape((num_heads, -1))
                .numpy()
            )

            keras_nlp_model.get_layer(
                f"group_{i}_inner_layer_{j}"
            )._self_attention_layer._value_dense.kernel.assign(
                hf_wts[
                    f"encoder.albert_layer_groups.{i}.albert_layers.{j}.attention.value.weight"
                ]
                .transpose(1, 0)
                .reshape((hidden_dim, num_heads, -1))
                .numpy()
            )
            keras_nlp_model.get_layer(
                f"group_{i}_inner_layer_{j}"
            )._self_attention_layer._value_dense.bias.assign(
                hf_wts[
                    f"encoder.albert_layer_groups.{i}.albert_layers.{j}.attention.value.bias"
                ]
                .reshape((num_heads, -1))
                .numpy()
            )

            keras_nlp_model.get_layer(
                f"group_{i}_inner_layer_{j}"
            )._self_attention_layer._output_dense.kernel.assign(
                hf_wts[
                    f"encoder.albert_layer_groups.{i}.albert_layers.{j}.attention.dense.weight"
                ]
                .transpose(1, 0)
                .reshape((num_heads, -1, hidden_dim))
                .numpy()
            )
            keras_nlp_model.get_layer(
                f"group_{i}_inner_layer_{j}"
            )._self_attention_layer._output_dense.bias.assign(
                hf_wts[
                    f"encoder.albert_layer_groups.{i}.albert_layers.{j}.attention.dense.bias"
                ].numpy()
            )

            keras_nlp_model.get_layer(
                f"group_{i}_inner_layer_{j}"
            )._self_attention_layernorm.gamma.assign(
                hf_wts[
                    f"encoder.albert_layer_groups.{i}.albert_layers.{j}.attention.LayerNorm.weight"
                ].numpy()
            )
            keras_nlp_model.get_layer(
                f"group_{i}_inner_layer_{j}"
            )._self_attention_layernorm.beta.assign(
                hf_wts[
                    f"encoder.albert_layer_groups.{i}.albert_layers.{j}.attention.LayerNorm.bias"
                ].numpy()
            )

            keras_nlp_model.get_layer(
                f"group_{i}_inner_layer_{j}"
            )._feedforward_intermediate_dense.kernel.assign(
                hf_wts[
                    f"encoder.albert_layer_groups.{i}.albert_layers.{j}.ffn.weight"
                ]
                .transpose(1, 0)
                .numpy()
            )
            keras_nlp_model.get_layer(
                f"group_{i}_inner_layer_{j}"
            )._feedforward_intermediate_dense.bias.assign(
                hf_wts[
                    f"encoder.albert_layer_groups.{i}.albert_layers.{j}.ffn.bias"
                ].numpy()
            )

            keras_nlp_model.get_layer(
                f"group_{i}_inner_layer_{j}"
            )._feedforward_output_dense.kernel.assign(
                hf_wts[
                    f"encoder.albert_layer_groups.{i}.albert_layers.{j}.ffn_output.weight"
                ]
                .transpose(1, 0)
                .numpy()
            )
            keras_nlp_model.get_layer(
                f"group_{i}_inner_layer_{j}"
            )._feedforward_output_dense.bias.assign(
                hf_wts[
                    f"encoder.albert_layer_groups.{i}.albert_layers.{j}.ffn_output.bias"
                ].numpy()
            )

            keras_nlp_model.get_layer(
                f"group_{i}_inner_layer_{j}"
            )._feedforward_layernorm.gamma.assign(
                hf_wts[
                    f"encoder.albert_layer_groups.{i}.albert_layers.{j}.full_layer_layer_norm.weight"
                ].numpy()
            )
            keras_nlp_model.get_layer(
                f"group_{i}_inner_layer_{j}"
            )._feedforward_layernorm.beta.assign(
                hf_wts[
                    f"encoder.albert_layer_groups.{i}.albert_layers.{j}.full_layer_layer_norm.bias"
                ].numpy()
            )

    keras_nlp_model.get_layer("pooled_dense").kernel.assign(
        hf_wts["pooler.weight"].transpose(1, 0).numpy()
    )
    keras_nlp_model.get_layer("pooled_dense").bias.assign(
        hf_wts["pooler.bias"].numpy()
    )

    # Save the model.
    print("\n-> Save KerasNLP model weights.")
    keras_nlp_model.save_weights(os.path.join(FLAGS.preset, "model.h5"))

    return keras_nlp_model


def extract_vocab(hf_tokenizer):
    spm_path = os.path.join(FLAGS.preset, "spiece.model")
    print(f"\n-> Save KerasNLP SPM vocabulary file to `{spm_path}`.")

    shutil.copyfile(
        transformers.utils.hub.get_file_from_repo(
            hf_tokenizer.name_or_path, "spiece.model"
        ),
        spm_path,
    )

    keras_nlp_tokenizer = keras_nlp.models.AlbertTokenizer(
        proto=spm_path,
    )
    keras_nlp_preprocessor = keras_nlp.models.AlbertPreprocessor(
        keras_nlp_tokenizer
    )

    print("-> Print MD5 checksum of the vocab files.")
    print(f"`{spm_path}` md5sum: ", get_md5_checksum(spm_path))

    return keras_nlp_preprocessor


def check_output(
    keras_nlp_preprocessor,
    keras_nlp_model,
    hf_tokenizer,
    hf_model,
):
    print("\n-> Check the outputs.")
    sample_text = ["cricket is awesome, easily the best sport in the world!"]

    # KerasNLP
    keras_nlp_inputs = keras_nlp_preprocessor(tf.constant(sample_text))
    keras_nlp_output = keras_nlp_model.predict(keras_nlp_inputs)[
        "sequence_output"
    ]

    # HF
    hf_inputs = hf_tokenizer(
        sample_text, padding="max_length", return_tensors="pt"
    )
    hf_output = hf_model(**hf_inputs).last_hidden_state

    print("KerasNLP output:", keras_nlp_output[0, 0, :10])
    print("HF output:", hf_output[0, 0, :10])
    print("Difference:", np.mean(keras_nlp_output - hf_output.detach().numpy()))

    # Show the MD5 checksum of the model weights.
    print(
        "Model md5sum: ",
        get_md5_checksum(os.path.join(FLAGS.preset, "model.h5")),
    )


def main(_):
    os.makedirs(FLAGS.preset)

    hf_model_name = PRESET_MAP[FLAGS.preset]

    print("\n-> Load HF model and HF tokenizer.")
    hf_model = transformers.AutoModel.from_pretrained(hf_model_name)
    hf_model.eval()
    hf_tokenizer = transformers.AutoTokenizer.from_pretrained(hf_model_name)

    keras_nlp_model = convert_checkpoints(hf_model)
    print("\n -> Load KerasNLP preprocessor.")
    keras_nlp_preprocessor = extract_vocab(hf_tokenizer)

    check_output(
        keras_nlp_preprocessor,
        keras_nlp_model,
        hf_tokenizer,
        hf_model,
    )


if __name__ == "__main__":
    flags.mark_flag_as_required("preset")
    app.run(main)
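After the script runs, the converted artifacts land in a directory named after the preset. As a quick sanity check they can be loaded back with the same KerasNLP APIs the script itself uses. The snippet below is a minimal sketch, not part of this commit; it assumes the script has already been run with --preset="albert_base_en_uncased", so that albert_base_en_uncased/model.h5 and albert_base_en_uncased/spiece.model exist in the working directory.

import keras_nlp

# Assumed: the conversion script was already run with this preset.
preset = "albert_base_en_uncased"

# Rebuild the backbone architecture without pretrained weights, then load the
# converted checkpoint saved by convert_checkpoints().
backbone = keras_nlp.models.AlbertBackbone.from_preset(preset, load_weights=False)
backbone.load_weights(f"{preset}/model.h5")

# Recreate the tokenizer from the SentencePiece proto copied by extract_vocab().
tokenizer = keras_nlp.models.AlbertTokenizer(proto=f"{preset}/spiece.model")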
