From b1c70dd1e9f46994524e874a5c3b4285e39409ed Mon Sep 17 00:00:00 2001 From: patrickvonplaten Date: Fri, 13 Nov 2020 19:12:59 +0000 Subject: [PATCH 1/2] fix load weights --- src/transformers/modeling_t5.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/src/transformers/modeling_t5.py b/src/transformers/modeling_t5.py index 88861b10a8c3..26bf0b25cdf2 100644 --- a/src/transformers/modeling_t5.py +++ b/src/transformers/modeling_t5.py @@ -108,6 +108,7 @@ def load_tf_weights_in_t5(model, config, tf_checkpoint_path): continue pointer = model array = tf_weights[txt_name] + for m_name in name: if re.fullmatch(r"[A-Za-z]+_\d+", m_name): scope_names = re.split(r"_(\d+)", m_name) @@ -115,12 +116,31 @@ def load_tf_weights_in_t5(model, config, tf_checkpoint_path): scope_names = [m_name] if scope_names[0] in ["kernel", "scale", "embedding"]: pointer = getattr(pointer, "weight") + elif scope_names[0] == "self_attention": + pointer = getattr(pointer, "layer") + pointer = pointer[0] + elif scope_names[0] == "enc_dec_attention": + pointer = getattr(pointer, "layer") + pointer = pointer[1] + elif scope_names[0] == "dense_relu_dense": + pointer = getattr(pointer, "layer") + pointer = pointer[2] + elif scope_names[0] == "rms_norm": + if hasattr(pointer, "layer_norm"): + pointer = getattr(pointer, "layer_norm") + elif hasattr(pointer, "final_layer_norm"): + pointer = getattr(pointer, "final_layer_norm") elif scope_names[0] == "scale": pointer = getattr(pointer, "weight") elif scope_names[0] == "output_bias" or scope_names[0] == "beta": pointer = getattr(pointer, "bias") elif scope_names[0] == "squad": pointer = getattr(pointer, "classifier") + + elif scope_names[0] == "decoder" and name[1] == "logits": + continue + elif scope_names[0] == "logits": + pointer = getattr(pointer, "lm_head") else: try: pointer = getattr(pointer, scope_names[0]) From 05e41b9faa15162e7c3c1219d2abbc98b4f1dbc4 Mon Sep 17 00:00:00 2001 From: patrickvonplaten Date: Fri, 13 Nov 2020 19:23:28 +0000 Subject: [PATCH 2/2] delete line --- src/transformers/modeling_t5.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/transformers/modeling_t5.py b/src/transformers/modeling_t5.py index 26bf0b25cdf2..c77fa5f7fb58 100644 --- a/src/transformers/modeling_t5.py +++ b/src/transformers/modeling_t5.py @@ -136,7 +136,6 @@ def load_tf_weights_in_t5(model, config, tf_checkpoint_path): pointer = getattr(pointer, "bias") elif scope_names[0] == "squad": pointer = getattr(pointer, "classifier") - elif scope_names[0] == "decoder" and name[1] == "logits": continue elif scope_names[0] == "logits":