CLI: convert sharded PT models #17959
Conversation
```python
for path in pytorch_checkpoint_path:
    pt_path = os.path.abspath(path)
    logger.info(f"Loading PyTorch weights from {pt_path}")
    pt_state_dict.update(torch.load(pt_path, map_location="cpu"))
```
That's super nice 👍🏻
That is a nice first step, but ideally, we'd want to convert the shards one by one to avoid using too much RAM and be able to convert LLM checkpoints without needing a battle station.
haha yes, I had to spin up a machine with >100GB of RAM to convert the RegNets 😬
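For illustration, a minimal sketch of what shard-by-shard conversion could look like (not this PR's implementation; `copy_pt_weights_into_tf_model` is a hypothetical helper standing in for the actual PT->TF weight cross-loading logic):

```python
import gc

import torch


def convert_shard_by_shard(tf_model, shard_paths):
    for shard_path in shard_paths:
        # Load only the current shard on CPU, so peak RAM is bounded by the largest shard
        pt_state_dict = torch.load(shard_path, map_location="cpu")
        # Copy this shard's weights into the TF model (hypothetical helper)
        copy_pt_weights_into_tf_model(tf_model, pt_state_dict)
        # Free the shard before loading the next one
        del pt_state_dict
        gc.collect()
```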
ArthurZucker left a comment
LGTM, thanks for working on that!
BTW, could we add 2 tests?

TF shards -> PT probably won't work, but I will add the test for PT shards -> TF 👍
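For reference, such a PT shards -> TF test could look roughly like this (a sketch under assumed shapes, not the test added in this PR; the tiny config and `max_shard_size` value are just there to force sharding):

```python
import tempfile

from transformers import BertConfig, BertModel, TFBertModel

with tempfile.TemporaryDirectory() as tmp_dir:
    # Tiny config so the test stays fast
    config = BertConfig(
        hidden_size=32, num_hidden_layers=2, num_attention_heads=2, intermediate_size=64
    )
    pt_model = BertModel(config)
    # A small max_shard_size forces the PyTorch checkpoint to be saved as shards
    pt_model.save_pretrained(tmp_dir, max_shard_size="150KB")
    # With this PR, the sharded PT checkpoint can be cross-loaded into a TF model
    tf_model = TFBertModel.from_pretrained(tmp_dir, from_pt=True)
```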
sgugger left a comment
Nice improvements, thanks!
```python
elif from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME)):
    # Load from a sharded PyTorch checkpoint
    archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME)
    is_sharded = True
```
Nice addition, maybe we should also support loading from a remote sharded checkpoint with `from_pt=True`? (It should be its own PR if we decide to support this.)
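For context, `WEIGHTS_INDEX_NAME` points at the sharded checkpoint's index file (`pytorch_model.bin.index.json`), whose `weight_map` records which shard file holds each parameter. A minimal sketch of resolving that index into local shard paths (illustrative only, not the library's loading code):

```python
import json
import os

WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"


def resolve_local_shards(checkpoint_dir):
    # The index maps each parameter name to the shard file that stores it
    with open(os.path.join(checkpoint_dir, WEIGHTS_INDEX_NAME)) as f:
        index = json.load(f)
    # Deduplicate shard file names and keep a stable order
    shard_files = sorted(set(index["weight_map"].values()))
    return [os.path.join(checkpoint_dir, shard_file) for shard_file in shard_files]
```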
* sharded conversion; add flag to control max hidden error
* better hidden name matching
* Add test: load TF from PT shards
* fix test (PT data must be local)
What does this PR do?
This PR adds a major upgrade and a minor change to the `pt-to-tf` CLI.

Major upgrade: we can now convert sharded PT models. It updates how `from_pt` loading works so that it can load from shards, and it updates how the `pt-to-tf` CLI stores the converted models, using sharding when needed.

Minor change: adds a flag to control the maximum admissible hidden layer error. It is relatively common to find models where the outputs from the PT and TF models are nearly the same, but the hidden layers have a larger mismatch. This flag allows us to temporarily increase the admissible error if the model seems to be behaving properly (for instance, all RegNet models had a hidden layer difference between 1e-4 and 1e-2, but the outputs were behaving properly).
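As a toy illustration of what such a tolerance controls (not the CLI's actual code; the function name and the default tolerance below are assumptions), the conversion compares PT and TF hidden states against an admissible error along these lines:

```python
import numpy as np


def check_hidden_states(pt_hidden_states, tf_hidden_states, max_hidden_error=5e-5):
    # Both inputs: lists of per-layer numpy arrays with matching shapes.
    # The default tolerance here is illustrative, not the CLI's value.
    max_diff = max(
        float(np.max(np.abs(pt - tf))) for pt, tf in zip(pt_hidden_states, tf_hidden_states)
    )
    if max_diff > max_hidden_error:
        raise ValueError(
            f"Max hidden state difference {max_diff:.2e} exceeds the admissible error {max_hidden_error:.2e}"
        )
    return max_diff
```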
Example of sharded TF model PR, using the updated tools: https://huggingface.co/facebook/regnet-y-10b-seer-in1k/discussions/1