Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 12 additions & 17 deletions trl/trainer/reward_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,7 @@
# limitations under the License.

import contextlib
import logging
import os
import re
from collections import defaultdict
from collections.abc import Callable
from contextlib import contextmanager
Expand All @@ -41,6 +39,7 @@
set_seed,
)
from transformers.data.data_collator import DataCollatorMixin
from transformers.modeling_layers import GenericForSequenceClassification
from transformers.trainer_utils import EvalPrediction
from transformers.utils import is_peft_available

Expand All @@ -63,23 +62,18 @@
# initialized and triggers a harmless warning about uninitialized weights. We suppress just that specific warning to
# avoid confusing users.
@contextmanager
def suppress_from_pretrained_warning(logger: logging.Logger):
pattern = re.compile(
r"^Some weights of \S+ were not initialized from the model checkpoint at \S+ and are newly initialized: "
r"\[.*\]\nYou should probably TRAIN this model on a down-stream task to be able to use it for predictions and "
r"inference\.$"
)

class _Filter(logging.Filter):
def filter(self, record: logging.LogRecord) -> bool:
return not pattern.search(record.getMessage())

f = _Filter()
logger.addFilter(f)
def ignore_seqcls_score_missing_key():
# Scoped override: ignore only the expected seq-clf head key.
old = getattr(GenericForSequenceClassification, "_keys_to_ignore_on_load_missing", None)
merged = list(old) if old is not None else []
pattern = r"^score\.weight$"
if pattern not in merged:
merged.append(pattern)
GenericForSequenceClassification._keys_to_ignore_on_load_missing = merged
try:
yield
finally:
logger.removeFilter(f)
GenericForSequenceClassification._keys_to_ignore_on_load_missing = old


def get_dataset_column_names(dataset: Dataset | IterableDataset) -> list[str]:
Expand Down Expand Up @@ -310,7 +304,8 @@ def __init__(
# Distributed training requires device_map=None ("auto" fails)
if args.distributed_state.distributed_type in ["MULTI_GPU", "DEEPSPEED"]:
model_init_kwargs["device_map"] = None
model = create_model_from_path(model, AutoModelForSequenceClassification, **model_init_kwargs)
with ignore_seqcls_score_missing_key():
model = create_model_from_path(model, AutoModelForSequenceClassification, **model_init_kwargs)
else:
if args.model_init_kwargs is not None:
logger.warning(
Expand Down
Loading