Skip to content
16 changes: 13 additions & 3 deletions src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
import warnings
from collections.abc import Mapping
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Type, Union


# Integrations must be imported before ML frameworks:
Expand Down Expand Up @@ -358,6 +358,9 @@ class Trainer:
optimizers (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*, defaults to `(None, None)`):
A tuple containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your
model and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`.
optimizer_cls_and_kwargs (`Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]`, *optional*, defaults to `None`):
A tuple containing the optimizer class and keyword arguments to use.
Overrides `optim` and `optim_args` in `args`. Incompatible with passing an optimizer instance through the `optimizers` argument.
preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`, *optional*):
A function that preprocesses the logits right before caching them at each evaluation step. Must take two
tensors, the logits and the labels, and return the logits once processed as desired. The modifications made
Expand Down Expand Up @@ -401,7 +404,8 @@ def __init__(
compute_loss_func: Optional[Callable] = None,
compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
callbacks: Optional[List[TrainerCallback]] = None,
optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
optimizers: Tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
optimizer_cls_and_kwargs: Optional[Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]] = None,
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
):
if args is None:
Expand Down Expand Up @@ -597,6 +601,9 @@ def __init__(
self.compute_metrics = compute_metrics
self.preprocess_logits_for_metrics = preprocess_logits_for_metrics
self.optimizer, self.lr_scheduler = optimizers
self.optimizer_cls_and_kwargs = optimizer_cls_and_kwargs
if self.optimizer_cls_and_kwargs is not None and self.optimizer is not None:
raise RuntimeError("Passing both `optimizers` and `optimizer_cls_and_kwargs` arguments is incompatible.")
if model_init is not None and (self.optimizer is not None or self.lr_scheduler is not None):
raise RuntimeError(
"Passing a `model_init` is incompatible with providing the `optimizers` argument. "
Expand Down Expand Up @@ -1165,7 +1172,10 @@ def create_optimizer(self):
},
]

optimizer_cls, optimizer_kwargs = self.get_optimizer_cls_and_kwargs(self.args, opt_model)
if self.optimizer_cls_and_kwargs is not None:
optimizer_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs
else:
optimizer_cls, optimizer_kwargs = self.get_optimizer_cls_and_kwargs(self.args, opt_model)

# Overwrite `params` in case it's created by `get_optimizer_cls_and_kwargs`
# e.g. for GaLore optimizer.
Expand Down