diff --git a/src/transformers/modeling_tf_pytorch_utils.py b/src/transformers/modeling_tf_pytorch_utils.py index 59846a892533..888277101412 100644 --- a/src/transformers/modeling_tf_pytorch_utils.py +++ b/src/transformers/modeling_tf_pytorch_utils.py @@ -163,6 +163,10 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a new_key = key.replace("gamma", "weight") if "beta" in key: new_key = key.replace("beta", "bias") + if "running_var" in key: + new_key = key.replace("running_var", "moving_variance") + if "running_mean" in key: + new_key = key.replace("running_mean", "moving_mean") if new_key: old_keys.append(key) new_keys.append(new_key)