14 changes: 9 additions & 5 deletions src/transformers/trainer.py
@@ -2495,13 +2495,13 @@ def _inner_training_loop(
             step = -1
             epoch_iterator = iter(epoch_dataloader)
             # We chunkify the epoch iterator into gradient accumulation steps `n` batches
-            remainder = num_examples % args.gradient_accumulation_steps
+            remainder = steps_in_epoch % args.gradient_accumulation_steps
             if remainder == 0:
                 remainder = args.gradient_accumulation_steps
             update_step = -1
-            total_updates = steps_in_epoch // args.gradient_accumulation_steps + 1
-            if args.gradient_accumulation_steps == 1:
-                total_updates -= 1
+            total_updates = steps_in_epoch // args.gradient_accumulation_steps + int(
+                remainder < args.gradient_accumulation_steps
+            )
Comment on lines +2498 to +2504
Member

This should give the same results before and after, no? But I agree that it is a bit strange to use num_examples for remainder but not for total_updates.

Contributor Author
@efsotr May 6, 2025

When computing the remainder, there was an error where steps_in_epoch was mistakenly written as num_examples. Here, num_examples refers to the size of the dataset, while steps_in_epoch is the number of batches in the dataset.
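As a rough illustration of the distinction (hypothetical single-process numbers with drop_last=False, not taken from the PR's tests), the two operands generally give different remainders:

```python
import math

# Illustrative numbers only (hypothetical, single process, drop_last=False).
num_examples = 123                   # size of the dataset
per_device_train_batch_size = 4
gradient_accumulation_steps = 5

steps_in_epoch = math.ceil(num_examples / per_device_train_batch_size)  # 31 batches

print(num_examples % gradient_accumulation_steps)    # 3  <- old operand (wrong)
print(steps_in_epoch % gradient_accumulation_steps)  # 1  <- fixed operand
```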

Contributor Author

num_examples != steps_in_epoch

Contributor Author

When steps_in_epoch is a multiple of args.gradient_accumulation_steps, total_updates is incorrectly greater than expected by 1.
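A minimal sketch of the off-by-one (illustrative numbers only; the two formulas are copied from the diff above):

```python
gradient_accumulation_steps = 5

for steps_in_epoch in (30, 31):
    remainder = steps_in_epoch % gradient_accumulation_steps
    if remainder == 0:
        remainder = gradient_accumulation_steps

    # Old: unconditionally adds 1 (only corrected when gradient_accumulation_steps == 1).
    old_total_updates = steps_in_epoch // gradient_accumulation_steps + 1
    if gradient_accumulation_steps == 1:
        old_total_updates -= 1

    # New: adds 1 only when a partial chunk of batches is left over.
    new_total_updates = steps_in_epoch // gradient_accumulation_steps + int(
        remainder < gradient_accumulation_steps
    )

    print(steps_in_epoch, old_total_updates, new_total_updates)
    # 30 -> old 7, new 6  (30 batches are consumed in exactly 6 updates of 5)
    # 31 -> old 7, new 7  (6 full chunks plus 1 partial chunk)
```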

Member
@SunMarc May 6, 2025

Oh yeah indeed, my bad. Can you share the results of your tests before and after this PR in the description? That would help future readers!

             for _ in range(total_updates):
                 update_step += 1
                 num_batches = args.gradient_accumulation_steps if update_step != (total_updates - 1) else remainder
@@ -5319,7 +5319,11 @@ def set_initial_training_values(

         # Case 2: We have a dataloader length and can extrapolate
         if len_dataloader is not None:
-            num_update_steps_per_epoch = max(len_dataloader // args.gradient_accumulation_steps, 1)
+            num_update_steps_per_epoch = max(
+                len_dataloader // args.gradient_accumulation_steps
+                + int(len_dataloader % args.gradient_accumulation_steps > 0),
+                1,
+            )
Comment on lines +5322 to +5326
Member

This seems like the only real change, no?
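For what it's worth, the new expression is just a ceiling division spelled out; a small illustrative check (hypothetical inputs, not part of the PR):

```python
import math

gradient_accumulation_steps = 5

for len_dataloader in (30, 31, 3):
    old = max(len_dataloader // gradient_accumulation_steps, 1)
    new = max(
        len_dataloader // gradient_accumulation_steps
        + int(len_dataloader % gradient_accumulation_steps > 0),
        1,
    )
    assert new == max(math.ceil(len_dataloader / gradient_accumulation_steps), 1)
    print(len_dataloader, old, new)
    # 30 -> old 6, new 6
    # 31 -> old 6, new 7  (the trailing partial chunk now counts as an update step)
    #  3 -> old 1, new 1
```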

             # Case 3: We have a length but are using epochs, we can extrapolate the number of steps
             if epoch_based:
                 max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch)
32 changes: 32 additions & 0 deletions tests/trainer/test_trainer.py
@@ -97,6 +97,7 @@
     require_torch_fp16,
     require_torch_gpu,
     require_torch_multi_accelerator,
+    require_torch_multi_gpu,
     require_torch_non_multi_accelerator,
     require_torch_non_multi_gpu,
     require_torch_tensorrt_fx,
@@ -3763,6 +3764,37 @@ def test_num_train_epochs_in_training(self):
             train_output = trainer.train()
             self.assertEqual(train_output.global_step, int(self.n_epochs))
 
+    @require_torch_multi_gpu
+    def test_num_batches_in_training_with_gradient_accumulation(self):
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            for num_train_epochs in [1, 2]:
+                for train_len in [123, 120]:
+                    trainer = get_regression_trainer(
+                        train_len=train_len,
+                        per_device_train_batch_size=4,
+                        gradient_accumulation_steps=5,
+                        num_train_epochs=num_train_epochs,
+                        output_dir=tmp_dir,
+                    )
+
+                    total_batch_samples = []
+
+                    def wrap_get_batch_samples(fn):
+                        def wrapped_fn(epoch_iterator, num_batches, device):
+                            self.assertGreater(num_batches, 0)
+                            batch_samples, num_items_in_batch = fn(epoch_iterator, num_batches, device)
+                            self.assertEqual(len(batch_samples), num_batches)
+                            total_batch_samples.append(num_batches)
+                            return batch_samples, num_items_in_batch
+
+                        return wrapped_fn
+
+                    trainer.get_batch_samples = wrap_get_batch_samples(trainer.get_batch_samples)
+
+                    trainer.train()
+
+                    self.assertEqual(len(trainer.get_train_dataloader()) * num_train_epochs, sum(total_batch_samples))
+
     def test_early_stopping_callback(self):
         # early stopping stops training before num_training_epochs
         with tempfile.TemporaryDirectory() as tmp_dir: