Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 27 additions & 15 deletions trl/trainer/reward_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,14 +62,17 @@
logger = get_logger(__name__)


# AutoModelForSequenceClassification adds a new classification head when loading a CausalLM. That head is randomly
# initialized and triggers a harmless warning about uninitialized weights. We suppress just that specific warning to
# avoid confusing users.
# Loading a CausalLM checkpoint into AutoModelForSequenceClassification triggers harmless warnings:
# - MISSING score.weight : the new seq-clf head was not in the checkpoint and is randomly initialized.
# - UNEXPECTED lm_head.weight: the causal LM head is in the checkpoint but absent from seq-clf (>= 4.57.0 only).
# Both are expected consequences of intentional cross-architecture loading. We suppress them to avoid
# confusing users.


# Old approach using logging filter (for transformers < 4.57.0)
# Note: in transformers < 4.57.0, only the MISSING score.weight warning is emitted; lm_head.weight is not reported.
@contextmanager
def suppress_from_pretrained_warning(logger: logging.Logger):
def _suppress_seqcls_cross_arch_keys(logger: logging.Logger):
pattern = re.compile(
r"^Some weights of \S+ were not initialized from the model checkpoint at \S+ and are newly initialized: "
r"\[.*\]\nYou should probably TRAIN this model on a down-stream task to be able to use it for predictions and "
Expand All @@ -90,18 +93,27 @@ def filter(self, record: logging.LogRecord) -> bool:

# New approach using scoped override (for transformers >= 4.57.0)
@contextmanager
def ignore_seqcls_score_missing_key():
# Scoped override: ignore only the expected seq-clf head key.
old = getattr(GenericForSequenceClassification, "_keys_to_ignore_on_load_missing", None)
merged = list(old) if old is not None else []
pattern = r"^score\.weight$"
if pattern not in merged:
merged.append(pattern)
GenericForSequenceClassification._keys_to_ignore_on_load_missing = merged
def _ignore_seqcls_cross_arch_keys():
# Scoped override: ignore the expected seq-clf head key (newly added) and the causal LM head
# key (present in the checkpoint but absent from seq-clf).
old_missing = getattr(GenericForSequenceClassification, "_keys_to_ignore_on_load_missing", None)
old_unexpected = getattr(GenericForSequenceClassification, "_keys_to_ignore_on_load_unexpected", None)

merged_missing = list(old_missing) if old_missing is not None else []
if r"^score\.weight$" not in merged_missing:
merged_missing.append(r"^score\.weight$")

merged_unexpected = list(old_unexpected) if old_unexpected is not None else []
if r"^lm_head\." not in merged_unexpected:
merged_unexpected.append(r"^lm_head\.")

GenericForSequenceClassification._keys_to_ignore_on_load_missing = merged_missing
GenericForSequenceClassification._keys_to_ignore_on_load_unexpected = merged_unexpected
try:
yield
finally:
GenericForSequenceClassification._keys_to_ignore_on_load_missing = old
GenericForSequenceClassification._keys_to_ignore_on_load_missing = old_missing
GenericForSequenceClassification._keys_to_ignore_on_load_unexpected = old_unexpected


# Version-aware wrapper that chooses the appropriate approach
Expand All @@ -110,12 +122,12 @@ def suppress_seqcls_warning():
# Use the new approach for transformers >= 4.57.0, old approach for earlier versions
# The old approach is needed for 4.56.2 to avoid meta tensor issues with device_map=None
if Version(transformers.__version__) >= Version("4.57.0"):
with ignore_seqcls_score_missing_key():
with _ignore_seqcls_cross_arch_keys():
yield
else:
# Get the transformers logger
transformers_logger = logging.getLogger("transformers.modeling_utils")
with suppress_from_pretrained_warning(transformers_logger):
with _suppress_seqcls_cross_arch_keys(transformers_logger):
yield


Expand Down
Loading