diff --git a/trl/trainer/reward_trainer.py b/trl/trainer/reward_trainer.py
index 1a105abf6a7..cd079499b51 100644
--- a/trl/trainer/reward_trainer.py
+++ b/trl/trainer/reward_trainer.py
@@ -62,14 +62,17 @@
 logger = get_logger(__name__)
 
 
-# AutoModelForSequenceClassification adds a new classification head when loading a CausalLM. That head is randomly
-# initialized and triggers a harmless warning about uninitialized weights. We suppress just that specific warning to
-# avoid confusing users.
+# Loading a CausalLM checkpoint into AutoModelForSequenceClassification triggers harmless warnings:
+# - MISSING score.weight : the new seq-clf head was not in the checkpoint and is randomly initialized.
+# - UNEXPECTED lm_head.weight: the causal LM head is in the checkpoint but absent from seq-clf (>= 4.57.0 only).
+# Both are expected consequences of intentional cross-architecture loading. We suppress them to avoid
+# confusing users.
 
 
 # Old approach using logging filter (for transformers < 4.57.0)
+# Note: in transformers < 4.57.0, only the MISSING score.weight warning is emitted; lm_head.weight is not reported.
 @contextmanager
-def suppress_from_pretrained_warning(logger: logging.Logger):
+def _suppress_seqcls_cross_arch_keys(logger: logging.Logger):
     pattern = re.compile(
         r"^Some weights of \S+ were not initialized from the model checkpoint at \S+ and are newly initialized: "
         r"\[.*\]\nYou should probably TRAIN this model on a down-stream task to be able to use it for predictions and "
@@ -90,18 +93,27 @@ def filter(self, record: logging.LogRecord) -> bool:
 
 # New approach using scoped override (for transformers >= 4.57.0)
 @contextmanager
-def ignore_seqcls_score_missing_key():
-    # Scoped override: ignore only the expected seq-clf head key.
-    old = getattr(GenericForSequenceClassification, "_keys_to_ignore_on_load_missing", None)
-    merged = list(old) if old is not None else []
-    pattern = r"^score\.weight$"
-    if pattern not in merged:
-        merged.append(pattern)
-    GenericForSequenceClassification._keys_to_ignore_on_load_missing = merged
+def _ignore_seqcls_cross_arch_keys():
+    # Scoped override: ignore the expected seq-clf head key (newly added) and the causal LM head
+    # key (present in the checkpoint but absent from seq-clf).
+    old_missing = getattr(GenericForSequenceClassification, "_keys_to_ignore_on_load_missing", None)
+    old_unexpected = getattr(GenericForSequenceClassification, "_keys_to_ignore_on_load_unexpected", None)
+
+    merged_missing = list(old_missing) if old_missing is not None else []
+    if r"^score\.weight$" not in merged_missing:
+        merged_missing.append(r"^score\.weight$")
+
+    merged_unexpected = list(old_unexpected) if old_unexpected is not None else []
+    if r"^lm_head\." not in merged_unexpected:
+        merged_unexpected.append(r"^lm_head\.")
+
+    GenericForSequenceClassification._keys_to_ignore_on_load_missing = merged_missing
+    GenericForSequenceClassification._keys_to_ignore_on_load_unexpected = merged_unexpected
     try:
         yield
     finally:
-        GenericForSequenceClassification._keys_to_ignore_on_load_missing = old
+        GenericForSequenceClassification._keys_to_ignore_on_load_missing = old_missing
+        GenericForSequenceClassification._keys_to_ignore_on_load_unexpected = old_unexpected
 
 
 # Version-aware wrapper that chooses the appropriate approach
@@ -110,12 +122,12 @@ def suppress_seqcls_warning():
     # Use the new approach for transformers >= 4.57.0, old approach for earlier versions
     # The old approach is needed for 4.56.2 to avoid meta tensor issues with device_map=None
     if Version(transformers.__version__) >= Version("4.57.0"):
-        with ignore_seqcls_score_missing_key():
+        with _ignore_seqcls_cross_arch_keys():
             yield
     else:
         # Get the transformers logger
         transformers_logger = logging.getLogger("transformers.modeling_utils")
-        with suppress_from_pretrained_warning(transformers_logger):
+        with _suppress_seqcls_cross_arch_keys(transformers_logger):
             yield
 
 