diff --git a/src/transformers/modeling_auto.py b/src/transformers/modeling_auto.py index fad40db2fa25..3ec971325075 100644 --- a/src/transformers/modeling_auto.py +++ b/src/transformers/modeling_auto.py @@ -494,8 +494,8 @@ - The model is a model provided by the library (loaded with the `shortcut name` string of a pretrained model). - The model was saved using :meth:`~transformers.PreTrainedModel.save_pretrained` and is reloaded - by suppling the save directory. - - The model is loaded by suppling a local directory as ``pretrained_model_name_or_path`` and a + by supplying the save directory. + - The model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory. state_dict (`Dict[str, torch.Tensor]`, `optional`): A state dictionary to use instead of a state dictionary loaded from saved weights file. diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index f33275c6164f..ab3523b8724e 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -550,8 +550,8 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): - The model is a model provided by the library (loaded with the `shortcut name` string of a pretrained model). - The model was saved using :func:`~transformers.TFPreTrainedModel.save_pretrained` and is reloaded - by suppling the save directory. - - The model is loaded by suppling a local directory as ``pretrained_model_name_or_path`` and a + by supplying the save directory. + - The model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory. from_pt: (:obj:`bool`, `optional`, defaults to :obj:`False`): Load the model weights from a PyTorch state_dict save file (see docstring of diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 195df8681078..b7a87f99a179 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -784,8 +784,8 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): - The model is a model provided by the library (loaded with the `shortcut name` string of a pretrained model). - The model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded - by suppling the save directory. - - The model is loaded by suppling a local directory as ``pretrained_model_name_or_path`` and a + by supplying the save directory. + - The model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory. state_dict (:obj:`Dict[str, torch.Tensor]`, `optional`): A state dictionary to use instead of a state dictionary loaded from saved weights file. diff --git a/tests/test_doc_samples.py b/tests/test_doc_samples.py index 716537c95584..8e945bae9db9 100644 --- a/tests/test_doc_samples.py +++ b/tests/test_doc_samples.py @@ -36,8 +36,8 @@ def analyze_directory( self, directory: Path, identifier: Union[str, None] = None, - ignore_files: Union[List[str], None] = [], - n_identifier: Union[str, None] = None, + ignore_files: Union[List[str], None] = None, + n_identifier: Union[str, List[str], None] = None, only_modules: bool = True, ): """ @@ -45,7 +45,7 @@ def analyze_directory( the doctests in those files Args: - directory (:obj:`str`): Directory containing the files + directory (:obj:`Path`): Directory containing the files identifier (:obj:`str`): Will parse files containing this ignore_files (:obj:`List[str]`): List of files to skip n_identifier (:obj:`str` or :obj:`List[str]`): Will not parse files containing this/these identifiers. @@ -63,6 +63,7 @@ def analyze_directory( else: files = [file for file in files if n_identifier not in file] + ignore_files = ignore_files or [] ignore_files.append("__init__.py") files = [file for file in files if file not in ignore_files] @@ -71,8 +72,8 @@ def analyze_directory( print("Testing", file) if only_modules: + module_identifier = file.split(".")[0] try: - module_identifier = file.split(".")[0] module_identifier = getattr(transformers, module_identifier) suite = doctest.DocTestSuite(module_identifier) result = unittest.TextTestRunner().run(suite) @@ -84,7 +85,7 @@ def analyze_directory( self.assertIs(result.failed, 0) def test_modeling_examples(self): - transformers_directory = "src/transformers" + transformers_directory = Path("src/transformers") files = "modeling" ignore_files = [ "modeling_ctrl.py",