Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4461,6 +4461,7 @@ def push_to_hub(
commit_message: Optional[str] = "End of training",
blocking: bool = True,
token: Optional[str] = None,
revision: Optional[str] = None,
**kwargs,
) -> str:
"""
Expand All @@ -4473,6 +4474,8 @@ def push_to_hub(
Whether the function should return only when the `git push` has finished.
token (`str`, *optional*, defaults to `None`):
Token with write permission to overwrite Trainer's original args.
revision (`str`, *optional*):
The git revision to commit from. Defaults to the head of the "main" branch.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't it the branch to commit to - rather than the commit to commit from?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this description comes from the hub upload_folder api. I was only going to use it for branches and tested with that, so not sure if it works for other commit hashes.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK!

kwargs (`Dict[str, Any]`, *optional*):
Additional keyword arguments passed along to [`~Trainer.create_model_card`].

Expand Down Expand Up @@ -4526,6 +4529,7 @@ def push_to_hub(
token=token,
run_as_future=not blocking,
ignore_patterns=["_*", f"{PREFIX_CHECKPOINT_DIR}-*"],
revision=revision,
)

#
Expand Down
21 changes: 20 additions & 1 deletion tests/trainer/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from unittest.mock import Mock, patch

import numpy as np
from huggingface_hub import HfFolder, ModelCard, delete_repo, list_repo_commits, list_repo_files
from huggingface_hub import HfFolder, ModelCard, create_branch, delete_repo, list_repo_commits, list_repo_files
from parameterized import parameterized
from requests.exceptions import HTTPError

Expand Down Expand Up @@ -3933,6 +3933,25 @@ def test_push_to_hub_tags(self):
model_card = ModelCard.load(repo_name)
self.assertTrue("test-trainer-tags" in model_card.data.tags)

def test_push_to_hub_with_revision(self):
# Checks if `trainer.push_to_hub()` works correctly by adding revision
with tempfile.TemporaryDirectory() as tmp_dir:
trainer = get_regression_trainer(
output_dir=os.path.join(tmp_dir, "test-trainer-revision"),
push_to_hub=True,
hub_token=self._token,
)
branch = "v1.0"
create_branch(repo_id=trainer.hub_model_id, branch=branch, token=self._token, exist_ok=True)
url = trainer.push_to_hub(revision=branch)

# Extract branch from the url
re_search = re.search(r"tree/([^/]+)/", url)
self.assertIsNotNone(re_search)

branch_name = re_search.groups()[0]
self.assertEqual(branch_name, branch)


@require_torch
@require_optuna
Expand Down