diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index 29c4d67510c3..dcedd2946b85 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -2905,9 +2905,10 @@ def push_to_hub( use_temp_dir: Optional[bool] = None, commit_message: Optional[str] = None, private: Optional[bool] = None, - use_auth_token: Optional[Union[bool, str]] = None, max_shard_size: Optional[Union[int, str]] = "10GB", - **model_card_kwargs, + use_auth_token: Optional[Union[bool, str]] = None, + create_pr: bool = False, + **base_model_card_args, ) -> str: """ Upload the model files to the 🤗 Model Hub while synchronizing a local clone of the repo in `repo_path_or_name`. @@ -2931,8 +2932,8 @@ def push_to_hub( Only applicable for models. The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size lower than this size. If expressed as a string, needs to be digits followed by a unit (like `"5MB"`). - model_card_kwargs: - Additional keyword arguments passed along to the [`~TFPreTrainedModel.create_model_card`] method. + create_pr (`bool`, *optional*, defaults to `False`): + Whether or not to create a PR with the uploaded files or directly commit. Examples: @@ -2948,15 +2949,15 @@ def push_to_hub( model.push_to_hub("huggingface/my-finetuned-bert") ``` """ - if "repo_path_or_name" in model_card_kwargs: + if "repo_path_or_name" in base_model_card_args: warnings.warn( "The `repo_path_or_name` argument is deprecated and will be removed in v5 of Transformers. Use " "`repo_id` instead." ) - repo_id = model_card_kwargs.pop("repo_path_or_name") + repo_id = base_model_card_args.pop("repo_path_or_name") # Deprecation warning will be sent after for repo_url and organization - repo_url = model_card_kwargs.pop("repo_url", None) - organization = model_card_kwargs.pop("organization", None) + repo_url = base_model_card_args.pop("repo_url", None) + organization = base_model_card_args.pop("organization", None) if os.path.isdir(repo_id): working_dir = repo_id @@ -2982,11 +2983,16 @@ def push_to_hub( "output_dir": work_dir, "model_name": Path(repo_id).name, } - base_model_card_args.update(model_card_kwargs) + base_model_card_args.update(base_model_card_args) self.create_model_card(**base_model_card_args) self._upload_modified_files( - work_dir, repo_id, files_timestamps, commit_message=commit_message, token=use_auth_token + work_dir, + repo_id, + files_timestamps, + commit_message=commit_message, + token=use_auth_token, + create_pr=create_pr, ) @classmethod diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py index 9cfd314dfcbf..42db9aea2698 100644 --- a/tests/test_modeling_tf_common.py +++ b/tests/test_modeling_tf_common.py @@ -85,6 +85,7 @@ TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING, TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING, BertConfig, + PreTrainedModel, PushToHubCallback, RagRetriever, TFAutoModel, @@ -92,6 +93,7 @@ TFBertForMaskedLM, TFBertForSequenceClassification, TFBertModel, + TFPreTrainedModel, TFRagModel, TFSharedEmbeddings, ) @@ -2466,6 +2468,7 @@ def test_push_to_hub(self): break self.assertTrue(models_equal) + @is_pt_tf_cross_test def test_push_to_hub_callback(self): config = BertConfig( vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37 @@ -2489,6 +2492,12 @@ def test_push_to_hub_callback(self): break self.assertTrue(models_equal) + tf_push_to_hub_params = dict(inspect.signature(TFPreTrainedModel.push_to_hub).parameters) + tf_push_to_hub_params.pop("base_model_card_args") + pt_push_to_hub_params = dict(inspect.signature(PreTrainedModel.push_to_hub).parameters) + pt_push_to_hub_params.pop("deprecated_kwargs") + self.assertDictEaual(tf_push_to_hub_params, pt_push_to_hub_params) + def test_push_to_hub_in_organization(self): config = BertConfig( vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37