
Commit f54d24c

Add DeBERTaV3 Conversion Script (#633)
* Add DeBERTaV3 conversion script
* Fix padding issue
1 parent 7774071 commit f54d24c
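
For context, the script takes a single required --preset flag (one of the keys in PRESET_MAP below); a typical invocation would look something like "python convert_deberta_v3_checkpoints.py --preset deberta_v3_base_en", where the script file name is assumed here since it is not shown on this page.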

File tree

1 file changed: +326, −0 lines

# Copyright 2022 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os

import numpy as np
import requests
import tensorflow as tf
import transformers
from absl import app
from absl import flags

from keras_nlp.models.deberta_v3.deberta_v3_backbone import DebertaV3Backbone
from keras_nlp.models.deberta_v3.deberta_v3_preprocessor import (
    DebertaV3Preprocessor,
)
from keras_nlp.models.deberta_v3.deberta_v3_tokenizer import DebertaV3Tokenizer
from tools.checkpoint_conversion.checkpoint_conversion_utils import (
    get_md5_checksum,
)

PRESET_MAP = {
    "deberta_v3_extra_small_en": "microsoft/deberta-v3-xsmall",
    "deberta_v3_small_en": "microsoft/deberta-v3-small",
    "deberta_v3_base_en": "microsoft/deberta-v3-base",
    "deberta_v3_large_en": "microsoft/deberta-v3-large",
    "deberta_v3_base_multi": "microsoft/mdeberta-v3-base",
}

EXTRACT_DIR = "./{}"

FLAGS = flags.FLAGS
flags.DEFINE_string(
    "preset", None, f'Must be one of {",".join(PRESET_MAP.keys())}'
)


def download_files(preset, hf_model_name):
    print("-> Download original vocabulary and config.")

    extract_dir = EXTRACT_DIR.format(preset)
    if not os.path.exists(extract_dir):
        os.makedirs(extract_dir)

    # Config.
    config_path = os.path.join(extract_dir, "config.json")
    response = requests.get(
        f"https://huggingface.co/{hf_model_name}/raw/main/config.json"
    )
    open(config_path, "wb").write(response.content)
    print(f"`{config_path}`")

    # Vocab.
    spm_path = os.path.join(extract_dir, "spm.model")
    response = requests.get(
        f"https://huggingface.co/{hf_model_name}/resolve/main/spm.model"
    )
    open(spm_path, "wb").write(response.content)
    print(f"`{spm_path}`")


def define_preprocessor(preset, hf_model_name):
    print("\n-> Define the tokenizers.")
    extract_dir = EXTRACT_DIR.format(preset)
    spm_path = os.path.join(extract_dir, "spm.model")

    keras_nlp_tokenizer = DebertaV3Tokenizer(proto=spm_path)

    # Avoid having padding tokens. This is because the representations of the
    # padding token may be vastly different from the representations computed in
    # the original model. See https://github.com/keras-team/keras/pull/16619#issuecomment-1156338394.
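    # (The hard-coded lengths below presumably match the tokenized length of the
    # sample text used in check_output, so that no pad tokens are appended; the
    # multilingual tokenizer splits the same sentence into more pieces.)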
    sequence_length = 14
    if preset == "deberta_v3_base_multi":
        sequence_length = 17
    keras_nlp_preprocessor = DebertaV3Preprocessor(
        keras_nlp_tokenizer, sequence_length=sequence_length
    )

    hf_tokenizer = transformers.AutoTokenizer.from_pretrained(hf_model_name)

    print("\n-> Print MD5 checksum of the vocab files.")
    print(f"`{spm_path}` md5sum: ", get_md5_checksum(spm_path))

    return keras_nlp_preprocessor, hf_tokenizer


def convert_checkpoints(preset, keras_nlp_model, hf_model):
    print("\n-> Convert original weights to KerasNLP format.")

    extract_dir = EXTRACT_DIR.format(preset)
    config_path = os.path.join(extract_dir, "config.json")

    # Build config.
    cfg = {}
    with open(config_path, "r") as pt_cfg_handler:
        pt_cfg = json.load(pt_cfg_handler)
    cfg["vocabulary_size"] = pt_cfg["vocab_size"]
    cfg["num_layers"] = pt_cfg["num_hidden_layers"]
    cfg["num_heads"] = pt_cfg["num_attention_heads"]
    cfg["hidden_dim"] = pt_cfg["hidden_size"]
    cfg["intermediate_dim"] = pt_cfg["intermediate_size"]
    cfg["dropout"] = pt_cfg["hidden_dropout_prob"]
    cfg["max_sequence_length"] = pt_cfg["max_position_embeddings"]
    cfg["bucket_size"] = pt_cfg["position_buckets"]
    print("Config:", cfg)

    hf_wts = hf_model.state_dict()
    print("Original weights:")
    print(
        str(hf_wts.keys())
        .replace(", ", "\n")
        .replace("odict_keys([", "")
        .replace("]", "")
        .replace(")", "")
    )

    keras_nlp_model.get_layer("token_embedding").embeddings.assign(
        hf_wts["embeddings.word_embeddings.weight"]
    )
    keras_nlp_model.get_layer("embeddings_layer_norm").gamma.assign(
        hf_wts["embeddings.LayerNorm.weight"]
    )
    keras_nlp_model.get_layer("embeddings_layer_norm").beta.assign(
        hf_wts["embeddings.LayerNorm.bias"]
    )
    keras_nlp_model.get_layer("rel_embedding").rel_embeddings.assign(
        hf_wts["encoder.rel_embeddings.weight"]
    )
    keras_nlp_model.get_layer("rel_embedding").layer_norm.gamma.assign(
        hf_wts["encoder.LayerNorm.weight"]
    )
    keras_nlp_model.get_layer("rel_embedding").layer_norm.beta.assign(
        hf_wts["encoder.LayerNorm.bias"]
    )

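    # The HF checkpoints store each attention projection as a fused
    # (hidden_dim, hidden_dim) torch Linear weight in (out_features, in_features)
    # order, while the KerasNLP query/key/value kernels are per-head tensors of
    # shape (hidden_dim, num_heads, head_dim); the transposes and reshapes below
    # convert between the two layouts, and plain dense kernels only need a
    # transpose.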
    for i in range(keras_nlp_model.num_layers):
        # Q,K,V
        keras_nlp_model.get_layer(
            f"disentangled_attention_encoder_layer_{i}"
        )._self_attention_layer._query_dense.kernel.assign(
            hf_wts[f"encoder.layer.{i}.attention.self.query_proj.weight"]
            .numpy()
            .T.reshape((cfg["hidden_dim"], cfg["num_heads"], -1))
        )
        keras_nlp_model.get_layer(
            f"disentangled_attention_encoder_layer_{i}"
        )._self_attention_layer._query_dense.bias.assign(
            hf_wts[f"encoder.layer.{i}.attention.self.query_proj.bias"]
            .reshape((cfg["num_heads"], -1))
            .numpy()
        )

        keras_nlp_model.get_layer(
            f"disentangled_attention_encoder_layer_{i}"
        )._self_attention_layer._key_dense.kernel.assign(
            hf_wts[f"encoder.layer.{i}.attention.self.key_proj.weight"]
            .numpy()
            .T.reshape((cfg["hidden_dim"], cfg["num_heads"], -1))
        )
        keras_nlp_model.get_layer(
            f"disentangled_attention_encoder_layer_{i}"
        )._self_attention_layer._key_dense.bias.assign(
            hf_wts[f"encoder.layer.{i}.attention.self.key_proj.bias"]
            .reshape((cfg["num_heads"], -1))
            .numpy()
        )

        keras_nlp_model.get_layer(
            f"disentangled_attention_encoder_layer_{i}"
        )._self_attention_layer._value_dense.kernel.assign(
            hf_wts[f"encoder.layer.{i}.attention.self.value_proj.weight"]
            .numpy()
            .T.reshape((cfg["hidden_dim"], cfg["num_heads"], -1))
        )
        keras_nlp_model.get_layer(
            f"disentangled_attention_encoder_layer_{i}"
        )._self_attention_layer._value_dense.bias.assign(
            hf_wts[f"encoder.layer.{i}.attention.self.value_proj.bias"]
            .reshape((cfg["num_heads"], -1))
            .numpy()
        )

        # Attn output.
        keras_nlp_model.get_layer(
            f"disentangled_attention_encoder_layer_{i}"
        )._self_attention_layer._output_dense.kernel.assign(
            hf_wts[f"encoder.layer.{i}.attention.output.dense.weight"]
            .transpose(1, 0)
            .numpy()
        )
        keras_nlp_model.get_layer(
            f"disentangled_attention_encoder_layer_{i}"
        )._self_attention_layer._output_dense.bias.assign(
            hf_wts[f"encoder.layer.{i}.attention.output.dense.bias"].numpy()
        )

        keras_nlp_model.get_layer(
            f"disentangled_attention_encoder_layer_{i}"
        )._self_attention_layernorm.gamma.assign(
            hf_wts[
                f"encoder.layer.{i}.attention.output.LayerNorm.weight"
            ].numpy()
        )
        keras_nlp_model.get_layer(
            f"disentangled_attention_encoder_layer_{i}"
        )._self_attention_layernorm.beta.assign(
            hf_wts[f"encoder.layer.{i}.attention.output.LayerNorm.bias"].numpy()
        )

        # Intermediate FF layer.
        keras_nlp_model.get_layer(
            f"disentangled_attention_encoder_layer_{i}"
        )._feedforward_intermediate_dense.kernel.assign(
            hf_wts[f"encoder.layer.{i}.intermediate.dense.weight"]
            .transpose(1, 0)
            .numpy()
        )
        keras_nlp_model.get_layer(
            f"disentangled_attention_encoder_layer_{i}"
        )._feedforward_intermediate_dense.bias.assign(
            hf_wts[f"encoder.layer.{i}.intermediate.dense.bias"].numpy()
        )

        # Output FF layer.
        keras_nlp_model.get_layer(
            f"disentangled_attention_encoder_layer_{i}"
        )._feedforward_output_dense.kernel.assign(
            hf_wts[f"encoder.layer.{i}.output.dense.weight"].numpy().T
        )
        keras_nlp_model.get_layer(
            f"disentangled_attention_encoder_layer_{i}"
        )._feedforward_output_dense.bias.assign(
            hf_wts[f"encoder.layer.{i}.output.dense.bias"].numpy()
        )

        keras_nlp_model.get_layer(
            f"disentangled_attention_encoder_layer_{i}"
        )._feedforward_layernorm.gamma.assign(
            hf_wts[f"encoder.layer.{i}.output.LayerNorm.weight"].numpy()
        )
        keras_nlp_model.get_layer(
            f"disentangled_attention_encoder_layer_{i}"
        )._feedforward_layernorm.beta.assign(
            hf_wts[f"encoder.layer.{i}.output.LayerNorm.bias"].numpy()
        )

    # Save the model.
    print(f"\n-> Save KerasNLP model weights to `{preset}.h5`.")
    keras_nlp_model.save_weights(f"{preset}.h5")

    return keras_nlp_model


def check_output(
    preset,
    keras_nlp_preprocessor,
    keras_nlp_model,
    hf_tokenizer,
    hf_model,
):
    print("\n-> Check the outputs.")
    sample_text = ["cricket is awesome, easily the best sport in the world!"]

    # KerasNLP
    keras_nlp_inputs = keras_nlp_preprocessor(tf.constant(sample_text))
    keras_nlp_output = keras_nlp_model.predict(keras_nlp_inputs)

    # HF
    hf_inputs = hf_tokenizer(
        sample_text, padding="longest", return_tensors="pt"
    )
    hf_output = hf_model(**hf_inputs).last_hidden_state

    print("KerasNLP output:", keras_nlp_output[0, 0, :10])
    print("HF output:", hf_output[0, 0, :10])
    print("Difference:", np.mean(keras_nlp_output - hf_output.detach().numpy()))

    # Show the MD5 checksum of the model weights.
    print("Model md5sum: ", get_md5_checksum(f"./{preset}.h5"))


def main(_):
    hf_model_name = PRESET_MAP[FLAGS.preset]

    download_files(FLAGS.preset, hf_model_name)

    keras_nlp_preprocessor, hf_tokenizer = define_preprocessor(
        FLAGS.preset, hf_model_name
    )

    print("\n-> Load KerasNLP model.")
    keras_nlp_model = DebertaV3Backbone.from_preset(
        FLAGS.preset, load_weights=False
    )

    print("\n-> Load HF model.")
    hf_model = transformers.AutoModel.from_pretrained(hf_model_name)
    hf_model.eval()

    keras_nlp_model = convert_checkpoints(
        FLAGS.preset, keras_nlp_model, hf_model
    )

    check_output(
        FLAGS.preset,
        keras_nlp_preprocessor,
        keras_nlp_model,
        hf_tokenizer,
        hf_model,
    )


if __name__ == "__main__":
    flags.mark_flag_as_required("preset")
    app.run(main)
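
A minimal sketch of consuming the converted checkpoint (assuming the script was run with --preset deberta_v3_base_en and the resulting .h5 file is in the working directory; this snippet is not part of the commit):

from keras_nlp.models.deberta_v3.deberta_v3_backbone import DebertaV3Backbone

# Build the architecture from the preset config without downloading weights,
# then load the locally converted weights, mirroring the script above.
model = DebertaV3Backbone.from_preset("deberta_v3_base_en", load_weights=False)
model.load_weights("deberta_v3_base_en.h5")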
