diff --git a/convert_tf_checkpoint_to_pytorch.py b/convert_tf_checkpoint_to_pytorch.py old mode 100644 new mode 100755 index dfcdbee42d50..eeebb3728ee8 --- a/convert_tf_checkpoint_to_pytorch.py +++ b/convert_tf_checkpoint_to_pytorch.py @@ -68,11 +68,17 @@ def convert(): arrays.append(array) for name, array in zip(names, arrays): - name = name[5:] # skip "bert/" + if not name.startswith("bert"): + print("Skipping {}".format(name)) + continue + else: + name = name.replace("bert/", "") # skip "bert/" print("Loading {}".format(name)) name = name.split('/') - if name[0] in ['redictions', 'eq_relationship']: - print("Skipping") + # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v + # which are not required for using pretrained model + if name[0] in ['redictions', 'eq_relationship'] or name[-1] == "adam_v" or name[-1] == "adam_m": + print("Skipping {}".format("/".join(name))) continue pointer = model for m_name in name: