From 6720364a553a608744f4248103affe89d8fdbf58 Mon Sep 17 00:00:00 2001 From: kibitzing Date: Tue, 24 Sep 2024 12:59:04 +0900 Subject: [PATCH 1/7] replace total_batched_samples with step while counting grad accum step --- src/transformers/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index e0a49ee5795e..a3a9372d5030 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -2377,7 +2377,7 @@ def _inner_training_loop( ) if ( - total_batched_samples % args.gradient_accumulation_steps == 0 + (step + 1) % args.gradient_accumulation_steps == 0 or # last step in epoch but step is always smaller than gradient_accumulation_steps is_last_step_and_steps_less_than_grad_acc From 59273e69632eaef5af50d2df633c6158c74766cc Mon Sep 17 00:00:00 2001 From: kibitzing Date: Tue, 24 Sep 2024 13:00:09 +0900 Subject: [PATCH 2/7] remove unused variable --- src/transformers/trainer.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index a3a9372d5030..f3f4a8114f22 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -2282,7 +2282,6 @@ def _inner_training_loop( if args.eval_on_start: self._evaluate(trial, ignore_keys_for_eval, skip_scheduler=True) - total_batched_samples = 0 for epoch in range(epochs_trained, num_train_epochs): epoch_iterator = train_dataloader if hasattr(epoch_iterator, "set_epoch"): @@ -2312,8 +2311,6 @@ def _inner_training_loop( step = -1 for step, inputs in enumerate(epoch_iterator): - total_batched_samples += 1 - if self.args.include_num_input_tokens_seen: main_input_name = getattr(self.model, "main_input_name", "input_ids") if main_input_name not in inputs: From 81215c0cf2b2a77dbfba868c5181089329dffc53 Mon Sep 17 00:00:00 2001 From: kibitzing Date: Thu, 26 Sep 2024 22:03:05 +0900 Subject: [PATCH 3/7] simplify condition for update step --- src/transformers/trainer.py | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index f3f4a8114f22..03c850994716 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -2369,19 +2369,11 @@ def _inner_training_loop( self.current_flos += float(self.floating_point_ops(inputs)) - is_last_step_and_steps_less_than_grad_acc = ( - steps_in_epoch <= args.gradient_accumulation_steps and (step + 1) == steps_in_epoch - ) + is_last_step = (step + 1) == steps_in_epoch - if ( - (step + 1) % args.gradient_accumulation_steps == 0 - or - # last step in epoch but step is always smaller than gradient_accumulation_steps - is_last_step_and_steps_less_than_grad_acc - ): - # the `or` condition of `is_last_step_and_steps_less_than_grad_acc` is not covered - # in accelerate. So, explicitly enable sync gradients to True in that case. - if is_last_step_and_steps_less_than_grad_acc: + if ((step + 1) % args.gradient_accumulation_steps == 0 or is_last_step): + # `is_last_step` case is not covered in accelerate, explicitly enable sync gradients to True. + if is_last_step: self.accelerator.gradient_state._set_sync_gradients(True) # Gradient clipping From e4cc360955bbb8aef728161121c063d1e65c37dd Mon Sep 17 00:00:00 2001 From: kibitzing Date: Thu, 26 Sep 2024 22:49:37 +0900 Subject: [PATCH 4/7] fix format by ruff --- src/transformers/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 03c850994716..cd75dd62d15d 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -2371,7 +2371,7 @@ def _inner_training_loop( is_last_step = (step + 1) == steps_in_epoch - if ((step + 1) % args.gradient_accumulation_steps == 0 or is_last_step): + if (step + 1) % args.gradient_accumulation_steps == 0 or is_last_step: # `is_last_step` case is not covered in accelerate, explicitly enable sync gradients to True. if is_last_step: self.accelerator.gradient_state._set_sync_gradients(True) From 8348ff4e38df764dfe17735582777f08fd08e376 Mon Sep 17 00:00:00 2001 From: kibitzing Date: Mon, 30 Sep 2024 15:10:01 +0000 Subject: [PATCH 5/7] simplify update step condition using accelerator.sync_gradients --- src/transformers/trainer.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index cd75dd62d15d..c42399ae20d9 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -2369,13 +2369,7 @@ def _inner_training_loop( self.current_flos += float(self.floating_point_ops(inputs)) - is_last_step = (step + 1) == steps_in_epoch - - if (step + 1) % args.gradient_accumulation_steps == 0 or is_last_step: - # `is_last_step` case is not covered in accelerate, explicitly enable sync gradients to True. - if is_last_step: - self.accelerator.gradient_state._set_sync_gradients(True) - + if self.accelerator.sync_gradients: # Gradient clipping if args.max_grad_norm is not None and args.max_grad_norm > 0: # deepspeed does its own clipping @@ -4775,8 +4769,6 @@ def create_accelerator_and_postprocess(self): # take the gradient_accumulation_steps setting from TrainingArguments. grad_acc_kwargs["num_steps"] = self.args.gradient_accumulation_steps - grad_acc_kwargs["sync_with_dataloader"] = False - gradient_accumulation_plugin = GradientAccumulationPlugin(**grad_acc_kwargs) accelerator_config = self.args.accelerator_config.to_dict() From ca715e78cce7bc1bc9377ff34c3ccaf3dadfc0ae Mon Sep 17 00:00:00 2001 From: kibitzing Date: Fri, 25 Oct 2024 10:58:08 +0000 Subject: [PATCH 6/7] simplify update condition using do_sync_step --- src/transformers/trainer.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 7ffc3d38c717..e875a7c868d6 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -2425,12 +2425,8 @@ def _inner_training_loop( batch_samples, num_items_in_batch = self.get_batch_samples(epoch_iterator, num_batches) for inputs in batch_samples: step += 1 - is_last_step_and_steps_less_than_grad_acc = ( - steps_in_epoch <= args.gradient_accumulation_steps and (step + 1) == steps_in_epoch - ) - do_sync_step = is_last_step_and_steps_less_than_grad_acc or ( - step % args.gradient_accumulation_steps == 0 - ) + print(step, (step+1) % args.gradient_accumulation_steps) + do_sync_step = (step + 1) % args.gradient_accumulation_steps == 0 or (step + 1) == steps_in_epoch # Since we perform prefetching, we need to manually set sync_gradients if not do_sync_step: self.accelerator.gradient_state._set_sync_gradients(False) @@ -4908,6 +4904,7 @@ def create_accelerator_and_postprocess(self): grad_acc_kwargs["num_steps"] = self.args.gradient_accumulation_steps grad_acc_kwargs["sync_with_dataloader"] = False + gradient_accumulation_plugin = GradientAccumulationPlugin(**grad_acc_kwargs) accelerator_config = self.args.accelerator_config.to_dict() From 8f2102634c82be366ed1fd863d1bddb3fa9aaa13 Mon Sep 17 00:00:00 2001 From: kibitzing Date: Fri, 25 Oct 2024 11:20:09 +0000 Subject: [PATCH 7/7] remove print for test --- src/transformers/trainer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index e875a7c868d6..1e21d63d3af5 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -2425,7 +2425,6 @@ def _inner_training_loop( batch_samples, num_items_in_batch = self.get_batch_samples(epoch_iterator, num_batches) for inputs in batch_samples: step += 1 - print(step, (step+1) % args.gradient_accumulation_steps) do_sync_step = (step + 1) % args.gradient_accumulation_steps == 0 or (step + 1) == steps_in_epoch # Since we perform prefetching, we need to manually set sync_gradients if not do_sync_step: