Merged
40 changes: 16 additions & 24 deletions src/transformers/trainer.py
@@ -282,7 +282,7 @@ def __init__(
# Create output directory if needed
if self.is_world_process_zero():
os.makedirs(self.args.output_dir, exist_ok=True)
if is_torch_tpu_available():
if is_torch_tpu_available() and isinstance(self.model, PreTrainedModel):
# Set an xla_device flag on the model's config.
# We'll find a more elegant way and not need to do this in the future.
self.model.config.xla_device = True
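For context, the guard matters because the Trainer can now wrap a plain `torch.nn.Module`, which has no `config` attribute. A minimal standalone sketch of the pattern (the helper name is hypothetical, not part of the PR):

```python
import torch.nn as nn

from transformers import PreTrainedModel


def maybe_flag_xla_device(model: nn.Module) -> None:
    # Only PreTrainedModel is guaranteed to carry a `config`; a plain
    # nn.Module would raise AttributeError without the isinstance check.
    if isinstance(model, PreTrainedModel):
        model.config.xla_device = True
```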
@@ -490,11 +490,9 @@ def setup_wandb(self):
logger.info(
'Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"'
)
try:
combined_dict = {**self.model.config.to_dict(), **self.args.to_sanitized_dict()}
except AttributeError:
# in case the model has no config
combined_dict = {**self.args.to_sanitized_dict()}
combined_dict = {**self.args.to_sanitized_dict()}
if isinstance(self.model, PreTrainedModel):
combined_dict = {**self.model.config.to_dict(), **combined_dict}
wandb.init(
project=os.getenv("WANDB_PROJECT", "huggingface"), config=combined_dict, name=self.args.run_name
)
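The try/except on AttributeError is replaced by an explicit isinstance check, so unrelated AttributeErrors raised inside `to_dict()` or `to_sanitized_dict()` are no longer swallowed. As a small illustration of the merge order (stand-in dictionaries, not real Trainer values), args entries win on key collisions because they are unpacked last:

```python
config_dict = {"hidden_size": 768, "seed": 1}    # stand-in for model.config.to_dict()
args_dict = {"seed": 42, "learning_rate": 5e-5}  # stand-in for args.to_sanitized_dict()

combined = {**config_dict, **args_dict}
assert combined["seed"] == 42  # the later unpacking takes precedence
```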
@@ -533,7 +531,8 @@ def setup_comet(self):
if experiment is not None:
experiment._set_model_graph(self.model, framework="transformers")
experiment._log_parameters(self.args, prefix="args/", framework="transformers")
experiment._log_parameters(self.model.config, prefix="config/", framework="transformers")
if isinstance(self.model, PreTrainedModel):
experiment._log_parameters(self.model.config, prefix="config/", framework="transformers")

def num_examples(self, dataloader: DataLoader) -> int:
"""
@@ -679,7 +678,11 @@ def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", D
model,
device_ids=[self.args.local_rank],
output_device=self.args.local_rank,
find_unused_parameters=not getattr(model.config, "gradient_checkpointing", False),
find_unused_parameters=(
not getattr(model.config, "gradient_checkpointing", False)
if isinstance(model, PreTrainedModel)
else True
),
)
# find_unused_parameters breaks checkpointing as per
# https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021
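A hedged sketch of the resulting logic (a simplified stand-in, not the exact Trainer code): for a `PreTrainedModel`, `find_unused_parameters` is turned off when gradient checkpointing is enabled, since the two do not work together per the linked comment; for an arbitrary `nn.Module` there is no config to inspect, so the previous default of `True` is kept.

```python
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel

from transformers import PreTrainedModel


def wrap_ddp(model: nn.Module, local_rank: int) -> DistributedDataParallel:
    if isinstance(model, PreTrainedModel):
        # find_unused_parameters breaks gradient checkpointing (see PR #4659).
        find_unused = not getattr(model.config, "gradient_checkpointing", False)
    else:
        # A plain nn.Module has no config to inspect: keep the old default.
        find_unused = True
    return DistributedDataParallel(
        model,
        device_ids=[local_rank],
        output_device=local_rank,
        find_unused_parameters=find_unused,
    )
```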
@@ -707,30 +710,28 @@ def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", D

self.global_step = 0
self.epoch = 0
self.total_flos = 0
epochs_trained = 0
steps_trained_in_current_epoch = 0

# Check if continuing training from a checkpoint
if model_path is not None:
# set global_step to global_step of last saved checkpoint from model path
try:
self.global_step = int(model_path.split("-")[-1].split(os.path.sep)[0])
self.total_flos = getattr(self._actual_model(model).config, "total_flos", 0)

epochs_trained = self.global_step // num_update_steps_per_epoch
steps_trained_in_current_epoch = self.global_step % (num_update_steps_per_epoch)

logger.info(" Continuing training from checkpoint, will skip to saved global_step")
logger.info(" Continuing training from epoch %d", epochs_trained)
logger.info(" Continuing training from global step %d", self.global_step)
logger.info(" Continuing training from %d non-embedding floating-point operations", self.total_flos)
logger.info(" Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
except ValueError:
self.global_step = 0
self.total_flos = 0
logger.info(" Starting fine-tuning.")

tr_loss = torch.tensor(0.0).to(self.args.device)
self.total_flos = self.state.total_flos
logging_loss_scalar = 0.0
model.zero_grad()
disable_tqdm = self.args.disable_tqdm or not self.is_local_process_zero()
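The net effect of this hunk is that the running flos counter is seeded from the serializable `TrainerState` instead of being read back from `model.config.total_flos` on resume. A simplified sketch of that idea with hypothetical names (not the actual Trainer implementation):

```python
import json
from dataclasses import asdict, dataclass


@dataclass
class MiniTrainerState:
    global_step: int = 0
    total_flos: int = 0


def save_state(state: MiniTrainerState, path: str) -> None:
    # Persisted alongside the checkpoint, as TrainerState is in the PR.
    with open(path, "w") as f:
        json.dump(asdict(state), f)


def load_state(path: str) -> MiniTrainerState:
    with open(path) as f:
        return MiniTrainerState(**json.load(f))
```

On resume, the counter then comes from the restored state (`self.total_flos = self.state.total_flos`) rather than from the model config.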
@@ -1029,7 +1030,7 @@ def log(self, logs: Dict[str, float], iterator: Optional[tqdm] = None) -> None:
else:
total_flos = self.total_flos
if total_flos > 0:
logs["total_flos"] = self.total_flos
logs["total_flos"] = total_flos
if self.global_step is None:
# when logging evaluation metrics without training
self.global_step = 0
@@ -1245,11 +1246,9 @@ def store_flos(self):
# Storing the number of floating-point operations that went into the model
if self.total_flos is not None:
if self.args.local_rank != -1:
total_flos = distributed_broadcast_scalars([self.total_flos]).sum().item()
self.state.total_flos = distributed_broadcast_scalars([self.total_flos]).sum().item()
else:
total_flos = self.total_flos
if total_flos > 0:
self.model.config.total_flos = total_flos
self.state.total_flos = self.total_flos

def _sorted_checkpoints(self, checkpoint_prefix=PREFIX_CHECKPOINT_DIR, use_mtime=False) -> List[str]:
ordering_and_checkpoint_path = []
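`store_flos` now writes the aggregated count into `self.state.total_flos`. As a rough stand-in for what `distributed_broadcast_scalars(...).sum().item()` computes, assuming an initialized process group (this sketch uses `torch.distributed.all_gather` directly, not the library helper):

```python
import torch
import torch.distributed as dist


def global_total_flos(local_flos: float, device: torch.device) -> int:
    # Every rank contributes its local counter; all ranks receive the global sum.
    local = torch.tensor([local_flos], dtype=torch.float64, device=device)
    gathered = [torch.zeros_like(local) for _ in range(dist.get_world_size())]
    dist.all_gather(gathered, local)
    return int(torch.cat(gathered).sum().item())
```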
@@ -1363,13 +1362,6 @@ def prediction_loop(
prediction_loss_only if prediction_loss_only is not None else self.args.prediction_loss_only
)

assert not getattr(
self.model.config, "output_attentions", False
), "The prediction loop does not work with `output_attentions=True`."
assert not getattr(
self.model.config, "output_hidden_states", False
), "The prediction loop does not work with `output_hidden_states=True`."

Member commented:
Wouldn’t we want to put these lines inside an if statement? The prediction loop still doesn’t work with these outputs, right?

Collaborator (author) replied:
Nope, it does now, since the functions that detach/concat etc. all work on nested lists/tuples of tensors :-)

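For readers following the thread, a minimal sketch of what "works on nested lists/tuples" means here (a simplified stand-in for the Trainer's nested helpers, not their actual implementation):

```python
import torch


def nested_detach(tensors):
    """Detach a tensor, or recurse into a (possibly nested) list/tuple of tensors."""
    if isinstance(tensors, (list, tuple)):
        return type(tensors)(nested_detach(t) for t in tensors)
    return tensors.detach()
```

Model outputs that include attentions or hidden states (tuples of tensors) are therefore handled instead of tripping the removed assertions.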
model = self.model
# multi-gpu eval
if self.args.n_gpu > 1:
1 change: 1 addition & 0 deletions src/transformers/trainer_utils.py
@@ -224,6 +224,7 @@ class TrainerState:
A class containing the `Trainer` fields that will be saved along the model and optimizer.
"""

total_flos: int = 0
best_metric: Optional[float] = None
best_model_checkpoint: Optional[str] = None
