Rework how PreTrainedModel.from_pretrained handles its arguments #866

Merged: 4 commits, Jul 23, 2019
Changes from 2 commits
36 changes: 27 additions & 9 deletions pytorch_transformers/modeling_utils.py
@@ -78,7 +78,7 @@ def save_pretrained(self, save_directory):
self.to_json_file(output_config_file)

@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *input, **kwargs):
def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
r""" Instantiate a PretrainedConfig from a pre-trained model configuration.

Params:
@@ -105,6 +105,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *input, **kwargs):

"""
cache_dir = kwargs.pop('cache_dir', None)
return_unused_args = kwargs.pop('return_unused_args', False)

if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
config_file = cls.pretrained_config_archive_map[pretrained_model_name_or_path]
@@ -148,7 +149,10 @@ def from_pretrained(cls, pretrained_model_name_or_path, *input, **kwargs):
kwargs.pop(key, None)

logger.info("Model config %s", config)
return config
if return_unused_args:
return config, kwargs
else:
return config
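A minimal usage sketch of the new return_unused_args flag (illustrative only: the BertConfig class and the 'bert-base-uncased' shortcut are assumed here, and my_model_arg is a hypothetical key the configuration does not know about):

    from pytorch_transformers import BertConfig

    config, unused = BertConfig.from_pretrained(
        'bert-base-uncased',
        return_unused_args=True,
        my_model_arg=3,  # hypothetical key with no matching config attribute
    )
    # Keys matching config attributes update the config and are consumed;
    # everything else is handed back untouched.
    assert 'my_model_arg' in unused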

@classmethod
def from_dict(cls, json_object):
@@ -305,7 +309,7 @@ def save_pretrained(self, save_directory):
torch.save(model_to_save.state_dict(), output_model_file)

@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
r"""Instantiate a pretrained pytorch model from a pre-trained model configuration.

The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
@@ -336,9 +340,17 @@ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
configuration should be cached if the standard cache should not be used.
**output_loading_info**: (`optional`) boolean:
Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages.
**model_args**: (`optional`) Sequence:
All positional arguments will be passed to the underlying model's __init__ function
**kwargs**: (`optional`) dict:
Dictionnary of key, values to update the configuration object after loading.
Can be used to override selected configuration parameters. E.g. ``output_attention=True``
Dictionary of key, values to update the configuration object after loading.
Can be used to override selected configuration parameters. E.g. ``output_attention=True``.

If a ``config`` is provided (via the ``config`` keyword argument), it is used as-is and **kwargs will be passed to the model.
If no ``config`` is provided, then kwargs will be used to
override any keys shared with the default configuration for the
given pretrained_model_name_or_path, and only the unshared
key/value pairs will be passed to the model.

Examples::

Expand All @@ -359,7 +371,13 @@ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):

# Load config
if config is None:
config = cls.config_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
config, model_kwargs = cls.config_class.from_pretrained(
pretrained_model_name_or_path, *model_args,
cache_dir=cache_dir, return_unused_args=True,
**kwargs
)
else:
model_kwargs = kwargs
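To make the branch above concrete, a small self-contained sketch (not the library code) of the splitting step that happens when no config is supplied: keys matching existing configuration attributes update the config, and only the leftover keys end up in model_kwargs:

    def split_kwargs(config, **kwargs):
        """Illustrative stand-in for the override/split behaviour described above."""
        unused = {}
        for key, value in kwargs.items():
            if hasattr(config, key):
                setattr(config, key, value)  # shared key: override the configuration
            else:
                unused[key] = value          # unshared key: later forwarded to the model __init__
        return config, unused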

# Load model
if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
Expand Down Expand Up @@ -400,7 +418,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
archive_file, resolved_archive_file))

# Instantiate model.
model = cls(config)
model = cls(config, *model_args, **model_kwargs)
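A usage sketch of the reworked call (BertModel, BertConfig and the 'bert-base-uncased' shortcut are assumed for illustration, and output_attentions is assumed to be an attribute of the loaded configuration):

    from pytorch_transformers import BertConfig, BertModel

    # No config given: kwargs first update the loaded configuration; any key the
    # configuration does not recognise would be forwarded to the model __init__.
    model = BertModel.from_pretrained('bert-base-uncased', output_attentions=True)

    # Config given explicitly: it is used as-is and all remaining kwargs go
    # straight to the model constructor.
    config = BertConfig.from_pretrained('bert-base-uncased', output_attentions=True)
    model = BertModel.from_pretrained('bert-base-uncased', config=config)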

if state_dict is None and not from_tf:
state_dict = torch.load(resolved_archive_file, map_location='cpu')
@@ -530,7 +548,7 @@ def forward(self, hidden_states, start_states=None, start_positions=None, p_mask
**start_states**: ``torch.LongTensor`` of shape identical to hidden_states
hidden states of the first tokens for the labeled span.
**start_positions**: ``torch.LongTensor`` of shape ``(batch_size,)``
position of the first token for the labeled span:
position of the first token for the labeled span:
**p_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, seq_len)``
Mask of invalid positions such as query and special symbols (PAD, SEP, CLS)
1.0 means token should be masked.
@@ -717,7 +735,7 @@ class SequenceSummary(nn.Module):
- 'attn' => Not implemented now, use multi-head attention
summary_use_proj: Add a projection after the vector extraction
summary_proj_to_labels: If True, the projection outputs to config.num_labels classes (otherwise to hidden_size). Default: False.
summary_activation: 'tanh' => add a tanh activation to the output, Other => no activation. Default
summary_activation: 'tanh' => add a tanh activation to the output, Other => no activation. Default
summary_first_dropout: Add a dropout before the projection and activation
summary_last_dropout: Add a dropout after the projection and activation
"""