Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/transformers/configuration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ class PretrainedConfig(object):
- **is_composition** (:obj:`bool`): Whether the config class is composed of multiple sub-configs. In this case
the config has to be initialized from two or more configs of type :class:`~transformers.PretrainedConfig`
like: :class:`~transformers.EncoderDecoderConfig` or :class:`~RagConfig`.
- **keys_to_ignore_at_inference** (:obj:`List[str]`): A list of keys to ignore by default when looking at
dictionary outputs of the model during inference.
- **output_keys_to_ignore_at_inference** (:obj:`List[str]`): A list of keys to ignore by default when looking
at dictionary outputs of the model during inference.

Args:
name_or_path (:obj:`str`, `optional`, defaults to :obj:`""`):
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/bart/configuration_bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ class BartConfig(PretrainedConfig):
Whether or not the model should return the last key/values attentions (not used by all models).
"""
model_type = "bart"
keys_to_ignore_at_inference = ["past_key_values"]
output_keys_to_ignore_at_inference = ["past_key_values"]

def __init__(
self,
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/ctrl/configuration_ctrl.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ class CTRLConfig(PretrainedConfig):
"""

model_type = "ctrl"
keys_to_ignore_at_inference = ["past_key_values"]
output_keys_to_ignore_at_inference = ["past_key_values"]

def __init__(
self,
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/gpt2/configuration_gpt2.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ class GPT2Config(PretrainedConfig):
"""

model_type = "gpt2"
keys_to_ignore_at_inference = ["past_key_values"]
output_keys_to_ignore_at_inference = ["past_key_values"]

def __init__(
self,
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/marian/configuration_marian.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,4 +97,4 @@ class MarianConfig(BartConfig):
"""

model_type = "marian"
keys_to_ignore_at_inference = ["past_key_values"]
output_keys_to_ignore_at_inference = ["past_key_values"]
2 changes: 1 addition & 1 deletion src/transformers/models/mbart/configuration_mbart.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,4 +102,4 @@ class MBartConfig(BartConfig):
"""

model_type = "mbart"
keys_to_ignore_at_inference = ["past_key_values"]
output_keys_to_ignore_at_inference = ["past_key_values"]
2 changes: 1 addition & 1 deletion src/transformers/models/mt5/configuration_mt5.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ class MT5Config(PretrainedConfig):
Whether or not the model should return the last key/values attentions (not used by all models).
"""
model_type = "mt5"
keys_to_ignore_at_inference = ["past_key_values"]
output_keys_to_ignore_at_inference = ["past_key_values"]

def __init__(
self,
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/pegasus/configuration_pegasus.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,5 +141,5 @@ class PegasusConfig(BartConfig):
"""

model_type = "pegasus"
keys_to_ignore_at_inference = ["past_key_values"]
output_keys_to_ignore_at_inference = ["past_key_values"]
# The implementation of the config object is in BartConfig
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ class ProphetNetConfig(PretrainedConfig):
Whether or not the model should return the last key/values attentions (not used by all models).
"""
model_type = "prophetnet"
keys_to_ignore_at_inference = ["past_key_values"]
output_keys_to_ignore_at_inference = ["past_key_values"]

def __init__(
self,
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/reformer/configuration_reformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ class ReformerConfig(PretrainedConfig):
>>> configuration = model.config
"""
model_type = "reformer"
keys_to_ignore_at_inference = ["past_buckets_states"]
output_keys_to_ignore_at_inference = ["past_buckets_states"]

def __init__(
self,
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/t5/configuration_t5.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ class T5Config(PretrainedConfig):
Whether or not the model should return the last key/values attentions (not used by all models).
"""
model_type = "t5"
keys_to_ignore_at_inference = ["past_key_values"]
output_keys_to_ignore_at_inference = ["past_key_values"]

def __init__(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ class TransfoXLConfig(PretrainedConfig):
"""

model_type = "transfo-xl"
keys_to_ignore_at_inference = ["mems"]
output_keys_to_ignore_at_inference = ["mems"]

def __init__(
self,
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/xlnet/configuration_xlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ class XLNetConfig(PretrainedConfig):
"""

model_type = "xlnet"
keys_to_ignore_at_inference = ["mems"]
output_keys_to_ignore_at_inference = ["mems"]

def __init__(
self,
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1463,7 +1463,7 @@ def prediction_step(
inputs = self._prepare_inputs(inputs)
if ignore_keys is None:
if hasattr(self.model, "config"):
ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", [])
ignore_keys = getattr(self.model.config, "output_keys_to_ignore_at_inference", [])
else:
ignore_keys = []

Expand Down