Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 75 additions & 33 deletions src/transformers/commands/pt_to_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
is_tf_available,
is_torch_available,
)
from ..utils import logging
from ..utils import TF2_WEIGHTS_INDEX_NAME, TF2_WEIGHTS_NAME, logging
from . import BaseTransformersCLICommand


Expand All @@ -48,7 +48,6 @@


MAX_ERROR = 5e-5 # larger error tolerance than in our internal tests, to avoid flaky user-facing errors
TF_WEIGHTS_NAME = "tf_model.h5"


def convert_command_factory(args: Namespace):
Expand All @@ -58,7 +57,13 @@ def convert_command_factory(args: Namespace):
Returns: ServeCommand
"""
return PTtoTFCommand(
args.model_name, args.local_dir, args.new_weights, args.no_pr, args.push, args.extra_commit_description
args.model_name,
args.local_dir,
args.max_hidden_error,
args.new_weights,
args.no_pr,
args.push,
args.extra_commit_description,
)


Expand Down Expand Up @@ -90,6 +95,15 @@ def register_subcommand(parser: ArgumentParser):
default="",
help="Optional local directory of the model repository. Defaults to /tmp/{model_name}",
)
train_parser.add_argument(
"--max-hidden-error",
type=float,
default=MAX_ERROR,
help=(
f"Maximum error tolerance for hidden layer outputs. Defaults to {MAX_ERROR}. If you suspect the hidden"
" layers outputs will be used for downstream applications, avoid increasing this tolerance."
),
)
train_parser.add_argument(
"--new-weights",
action="store_true",
Expand All @@ -112,14 +126,10 @@ def register_subcommand(parser: ArgumentParser):
train_parser.set_defaults(func=convert_command_factory)

@staticmethod
def find_pt_tf_differences(pt_model, pt_input, tf_model, tf_input):
def find_pt_tf_differences(pt_outputs, tf_outputs):
"""
Compares the TensorFlow and PyTorch models, given their inputs, returning a dictionary with all tensor
differences.
Compares the TensorFlow and PyTorch outputs, returning a dictionary with all tensor differences.
"""
pt_outputs = pt_model(**pt_input, output_hidden_states=True)
tf_outputs = tf_model(**tf_input, output_hidden_states=True)

# 1. All output attributes must be the same
pt_out_attrs = set(pt_outputs.keys())
tf_out_attrs = set(tf_outputs.keys())
Expand Down Expand Up @@ -158,6 +168,7 @@ def __init__(
self,
model_name: str,
local_dir: str,
max_hidden_error: float,
new_weights: bool,
no_pr: bool,
push: bool,
Expand All @@ -167,6 +178,7 @@ def __init__(
self._logger = logging.get_logger("transformers-cli/pt_to_tf")
self._model_name = model_name
self._local_dir = local_dir if local_dir else os.path.join("/tmp", model_name)
self._max_hidden_error = max_hidden_error
self._new_weights = new_weights
self._no_pr = no_pr
self._push = push
Expand Down Expand Up @@ -260,34 +272,49 @@ def run(self):
pt_model = pt_class.from_pretrained(self._local_dir)
tf_from_pt_model = tf_class.from_pretrained(self._local_dir, from_pt=True)
pt_input, tf_input = self.get_inputs(pt_model, config)
pt_outputs = pt_model(**pt_input, output_hidden_states=True)
del pt_model # will no longer be used, and may have a large memory footprint

tf_from_pt_model = tf_class.from_pretrained(self._local_dir, from_pt=True)
tf_from_pt_outputs = tf_from_pt_model(**tf_input, output_hidden_states=True)

# Confirms that cross loading PT weights into TF worked.
crossload_differences = self.find_pt_tf_differences(pt_model, pt_input, tf_from_pt_model, tf_input)
max_crossload_diff = max(crossload_differences.values())
if max_crossload_diff > MAX_ERROR:
crossload_differences = self.find_pt_tf_differences(pt_outputs, tf_from_pt_outputs)
output_differences = {k: v for k, v in crossload_differences.items() if "hidden" not in k}
hidden_differences = {k: v for k, v in crossload_differences.items() if "hidden" in k}
max_crossload_output_diff = max(output_differences.values())
max_crossload_hidden_diff = max(hidden_differences.values())
if max_crossload_output_diff > MAX_ERROR or max_crossload_hidden_diff > self._max_hidden_error:
raise ValueError(
"The cross-loaded TensorFlow model has different outputs, something went wrong! Exaustive list of"
f" maximum tensor differences above the error threshold ({MAX_ERROR}):\n"
+ "\n".join(
[f"{key}: {value:.3e}" for key, value in crossload_differences.items() if value > MAX_ERROR]
)
"The cross-loaded TensorFlow model has different outputs, something went wrong!\n"
+ f"\nList of maximum output differences above the threshold ({MAX_ERROR}):\n"
+ "\n".join([f"{k}: {v:.3e}" for k, v in output_differences.items() if v > MAX_ERROR])
+ f"\n\nList of maximum hidden layer differences above the threshold ({self._max_hidden_error}):\n"
+ "\n".join([f"{k}: {v:.3e}" for k, v in hidden_differences.items() if v > self._max_hidden_error])
)

# Save the weights in a TF format (if needed) and confirms that the results are still good
tf_weights_path = os.path.join(self._local_dir, TF_WEIGHTS_NAME)
if not os.path.exists(tf_weights_path) or self._new_weights:
tf_from_pt_model.save_weights(tf_weights_path)
tf_weights_path = os.path.join(self._local_dir, TF2_WEIGHTS_NAME)
tf_weights_index_path = os.path.join(self._local_dir, TF2_WEIGHTS_INDEX_NAME)
if (not os.path.exists(tf_weights_path) and not os.path.exists(tf_weights_index_path)) or self._new_weights:
tf_from_pt_model.save_pretrained(self._local_dir)
del tf_from_pt_model # will no longer be used, and may have a large memory footprint

tf_model = tf_class.from_pretrained(self._local_dir)
conversion_differences = self.find_pt_tf_differences(pt_model, pt_input, tf_model, tf_input)
max_conversion_diff = max(conversion_differences.values())
if max_conversion_diff > MAX_ERROR:
tf_outputs = tf_model(**tf_input, output_hidden_states=True)

conversion_differences = self.find_pt_tf_differences(pt_outputs, tf_outputs)
output_differences = {k: v for k, v in conversion_differences.items() if "hidden" not in k}
hidden_differences = {k: v for k, v in conversion_differences.items() if "hidden" in k}
max_conversion_output_diff = max(output_differences.values())
max_conversion_hidden_diff = max(hidden_differences.values())
if max_conversion_output_diff > MAX_ERROR or max_conversion_hidden_diff > self._max_hidden_error:
raise ValueError(
"The converted TensorFlow model has different outputs, something went wrong! Exaustive list of maximum"
f" tensor differences above the error threshold ({MAX_ERROR}):\n"
+ "\n".join(
[f"{key}: {value:.3e}" for key, value in conversion_differences.items() if value > MAX_ERROR]
)
"The converted TensorFlow model has different outputs, something went wrong!\n"
+ f"\nList of maximum output differences above the threshold ({MAX_ERROR}):\n"
+ "\n".join([f"{k}: {v:.3e}" for k, v in output_differences.items() if v > MAX_ERROR])
+ f"\n\nList of maximum hidden layer differences above the threshold ({self._max_hidden_error}):\n"
+ "\n".join([f"{k}: {v:.3e}" for k, v in hidden_differences.items() if v > self._max_hidden_error])
)

commit_message = "Update TF weights" if self._new_weights else "Add TF weights"
Expand All @@ -300,16 +327,31 @@ def run(self):
self._logger.warn("Uploading the weights into a new PR...")
commit_descrition = (
"Model converted by the [`transformers`' `pt_to_tf`"
" CLI](https://github.com/huggingface/transformers/blob/main/src/transformers/commands/pt_to_tf.py)."
"\n\nAll converted model outputs and hidden layers were validated against its Pytorch counterpart."
f" Maximum crossload output difference={max_crossload_diff:.3e}; Maximum converted output"
f" difference={max_conversion_diff:.3e}."
" CLI](https://github.com/huggingface/transformers/blob/main/src/transformers/commands/pt_to_tf.py). "
"All converted model outputs and hidden layers were validated against its Pytorch counterpart.\n\n"
f"Maximum crossload output difference={max_crossload_output_diff:.3e}; "
f"Maximum crossload hidden layer difference={max_crossload_hidden_diff:.3e};\n"
f"Maximum conversion output difference={max_conversion_output_diff:.3e}; "
f"Maximum conversion hidden layer difference={max_conversion_hidden_diff:.3e};\n"
)
if self._extra_commit_description:
commit_descrition += "\n\n" + self._extra_commit_description

# sharded model -> adds all related files (index and .h5 shards)
if os.path.exists(tf_weights_index_path):
operations = [
CommitOperationAdd(path_in_repo=TF2_WEIGHTS_INDEX_NAME, path_or_fileobj=tf_weights_index_path)
]
for shard_path in tf.io.gfile.glob(self._local_dir + "/tf_model-*.h5"):
operations += [
CommitOperationAdd(path_in_repo=os.path.basename(shard_path), path_or_fileobj=shard_path)
]
else:
operations = [CommitOperationAdd(path_in_repo=TF2_WEIGHTS_NAME, path_or_fileobj=tf_weights_path)]

hub_pr_url = create_commit(
repo_id=self._model_name,
operations=[CommitOperationAdd(path_in_repo=TF_WEIGHTS_NAME, path_or_fileobj=tf_weights_path)],
operations=operations,
commit_message=commit_message,
commit_description=commit_descrition,
repo_type="model",
Expand Down
13 changes: 10 additions & 3 deletions src/transformers/modeling_tf_pytorch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,10 +117,17 @@ def load_pytorch_checkpoint_in_tf2_model(
)
raise

pt_path = os.path.abspath(pytorch_checkpoint_path)
logger.info(f"Loading PyTorch weights from {pt_path}")
# Treats a single file as a collection of shards with 1 shard.
if isinstance(pytorch_checkpoint_path, str):
pytorch_checkpoint_path = [pytorch_checkpoint_path]

# Loads all shards into a single state dictionary
pt_state_dict = {}
for path in pytorch_checkpoint_path:
pt_path = os.path.abspath(path)
logger.info(f"Loading PyTorch weights from {pt_path}")
pt_state_dict.update(torch.load(pt_path, map_location="cpu"))
Comment on lines +126 to +129
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's super nice 👍🏻

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That is a nice first step, but ideally, we'd want to convert the shards one by one to avoid using too much RAM and be able to convert LLMs checkpoints without needing a battle station.

Copy link
Contributor Author

@gante gante Jun 30, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

haha yes, I had to spin up a machine with >100GB of RAM to convert the RegNets 😬


pt_state_dict = torch.load(pt_path, map_location="cpu")
logger.info(f"PyTorch checkpoint contains {sum(t.numel() for t in pt_state_dict.values()):,} parameters")

return load_pytorch_weights_in_tf2_model(
Expand Down
7 changes: 6 additions & 1 deletion src/transformers/modeling_tf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
TF2_WEIGHTS_INDEX_NAME,
TF2_WEIGHTS_NAME,
WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
EntryNotFoundError,
ModelOutput,
Expand Down Expand Up @@ -2157,11 +2158,15 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
if from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
# Load from a PyTorch checkpoint in priority if from_pt
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
elif from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME)):
# Load from a sharded PyTorch checkpoint
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME)
is_sharded = True
Comment on lines +2161 to +2164
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice addition, maybe we should also support loading from a remote sharded checkpoint with from_pt=True? (It should be its own PR if we decide to support this.)

elif os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)):
# Load from a TF 2.0 checkpoint
archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_INDEX_NAME)):
# Load from a sharded PyTorch checkpoint
# Load from a sharded TF 2.0 checkpoint
archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_INDEX_NAME)
is_sharded = True
# At this stage we don't have a weight file so we will raise an error.
Expand Down
12 changes: 11 additions & 1 deletion tests/test_modeling_tf_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

from datasets import Dataset

from huggingface_hub import HfFolder, delete_repo, set_access_token
from huggingface_hub import HfFolder, Repository, delete_repo, set_access_token
from requests.exceptions import HTTPError
from transformers import is_tf_available, is_torch_available
from transformers.configuration_utils import PretrainedConfig
Expand Down Expand Up @@ -1966,6 +1966,16 @@ def test_checkpoint_sharding_from_hub(self):
for p1, p2 in zip(model.weights, ref_model.weights):
assert np.allclose(p1.numpy(), p2.numpy())

@is_pt_tf_cross_test
def test_checkpoint_sharding_local_from_pt(self):
with tempfile.TemporaryDirectory() as tmp_dir:
_ = Repository(local_dir=tmp_dir, clone_from="hf-internal-testing/tiny-random-bert-sharded")
model = TFBertModel.from_pretrained(tmp_dir, from_pt=True)
# the model above is the same as the model below, just a sharded pytorch version.
ref_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
for p1, p2 in zip(model.weights, ref_model.weights):
assert np.allclose(p1.numpy(), p2.numpy())

def test_shard_checkpoint(self):
# This is the model we will use, total size 340,000 bytes.
model = tf.keras.Sequential(
Expand Down