Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/transformers/modeling_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,8 +494,8 @@
- The model is a model provided by the library (loaded with the `shortcut name` string of a
pretrained model).
- The model was saved using :meth:`~transformers.PreTrainedModel.save_pretrained` and is reloaded
by suppling the save directory.
- The model is loaded by suppling a local directory as ``pretrained_model_name_or_path`` and a
by supplying the save directory.
- The model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a
configuration JSON file named `config.json` is found in the directory.
state_dict (`Dict[str, torch.Tensor]`, `optional`):
A state dictionary to use instead of a state dictionary loaded from saved weights file.
Expand Down
4 changes: 2 additions & 2 deletions src/transformers/modeling_tf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,8 +550,8 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
- The model is a model provided by the library (loaded with the `shortcut name` string of a
pretrained model).
- The model was saved using :func:`~transformers.TFPreTrainedModel.save_pretrained` and is reloaded
by suppling the save directory.
- The model is loaded by suppling a local directory as ``pretrained_model_name_or_path`` and a
by supplying the save directory.
- The model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a
configuration JSON file named `config.json` is found in the directory.
from_pt: (:obj:`bool`, `optional`, defaults to :obj:`False`):
Load the model weights from a PyTorch state_dict save file (see docstring of
Expand Down
4 changes: 2 additions & 2 deletions src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -784,8 +784,8 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
- The model is a model provided by the library (loaded with the `shortcut name` string of a
pretrained model).
- The model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded
by suppling the save directory.
- The model is loaded by suppling a local directory as ``pretrained_model_name_or_path`` and a
by supplying the save directory.
- The model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a
configuration JSON file named `config.json` is found in the directory.
state_dict (:obj:`Dict[str, torch.Tensor]`, `optional`):
A state dictionary to use instead of a state dictionary loaded from saved weights file.
Expand Down
11 changes: 6 additions & 5 deletions tests/test_doc_samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,16 +36,16 @@ def analyze_directory(
self,
directory: Path,
identifier: Union[str, None] = None,
ignore_files: Union[List[str], None] = [],
n_identifier: Union[str, None] = None,
ignore_files: Union[List[str], None] = None,
n_identifier: Union[str, List[str], None] = None,
only_modules: bool = True,
):
"""
Runs through the specific directory, looking for the files identified with `identifier`. Executes
the doctests in those files

Args:
directory (:obj:`str`): Directory containing the files
directory (:obj:`Path`): Directory containing the files
identifier (:obj:`str`): Will parse files containing this
ignore_files (:obj:`List[str]`): List of files to skip
n_identifier (:obj:`str` or :obj:`List[str]`): Will not parse files containing this/these identifiers.
Expand All @@ -63,6 +63,7 @@ def analyze_directory(
else:
files = [file for file in files if n_identifier not in file]

ignore_files = ignore_files or []
ignore_files.append("__init__.py")
files = [file for file in files if file not in ignore_files]

Expand All @@ -71,8 +72,8 @@ def analyze_directory(
print("Testing", file)

if only_modules:
module_identifier = file.split(".")[0]
try:
module_identifier = file.split(".")[0]
module_identifier = getattr(transformers, module_identifier)
suite = doctest.DocTestSuite(module_identifier)
result = unittest.TextTestRunner().run(suite)
Expand All @@ -84,7 +85,7 @@ def analyze_directory(
self.assertIs(result.failed, 0)

def test_modeling_examples(self):
transformers_directory = "src/transformers"
transformers_directory = Path("src/transformers")
files = "modeling"
ignore_files = [
"modeling_ctrl.py",
Expand Down