diff --git a/src/transformers/modeling_t5.py b/src/transformers/modeling_t5.py index 88861b10a8c3..c77fa5f7fb58 100644 --- a/src/transformers/modeling_t5.py +++ b/src/transformers/modeling_t5.py @@ -108,6 +108,7 @@ def load_tf_weights_in_t5(model, config, tf_checkpoint_path): continue pointer = model array = tf_weights[txt_name] + for m_name in name: if re.fullmatch(r"[A-Za-z]+_\d+", m_name): scope_names = re.split(r"_(\d+)", m_name) @@ -115,12 +116,30 @@ def load_tf_weights_in_t5(model, config, tf_checkpoint_path): scope_names = [m_name] if scope_names[0] in ["kernel", "scale", "embedding"]: pointer = getattr(pointer, "weight") + elif scope_names[0] == "self_attention": + pointer = getattr(pointer, "layer") + pointer = pointer[0] + elif scope_names[0] == "enc_dec_attention": + pointer = getattr(pointer, "layer") + pointer = pointer[1] + elif scope_names[0] == "dense_relu_dense": + pointer = getattr(pointer, "layer") + pointer = pointer[2] + elif scope_names[0] == "rms_norm": + if hasattr(pointer, "layer_norm"): + pointer = getattr(pointer, "layer_norm") + elif hasattr(pointer, "final_layer_norm"): + pointer = getattr(pointer, "final_layer_norm") elif scope_names[0] == "scale": pointer = getattr(pointer, "weight") elif scope_names[0] == "output_bias" or scope_names[0] == "beta": pointer = getattr(pointer, "bias") elif scope_names[0] == "squad": pointer = getattr(pointer, "classifier") + elif scope_names[0] == "decoder" and name[1] == "logits": + continue + elif scope_names[0] == "logits": + pointer = getattr(pointer, "lm_head") else: try: pointer = getattr(pointer, scope_names[0])