Bug in accelerate_sft_trainer.py: Incorrect calculation of total_steps #426

Closed
rockmagma02 opened this issue Apr 8, 2023 · 1 comment
Labels
bug Something isn't working


@rockmagma02

🐛 Describe the bug

I found a bug in your accelerate_sft_trainer.py file that I would like to report. It is in the prepare_learning() function. In the current implementation, self.total_steps is calculated by multiplying the number of epochs (self.config.train.epochs) by the length of train_dataloader. However, train_dataloader is the local loader created before accelerator preparation, so its length counts every batch in the whole dataset. When multiple GPUs are used, accelerator.prepare() shards the dataloader across processes, and each process iterates over only about 1/num_processes of those batches. The unprepared length therefore does not reflect the number of training steps each process actually takes.
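A minimal sketch of the length mismatch, using plain arithmetic instead of Accelerate itself. It assumes Accelerate's default sharding, where batches are dealt round-robin to processes and shorter shards are padded so that every process sees ceil(num_batches / num_processes) batches. The function names and the 10,000-sample, batch-size-8, 4-GPU numbers are hypothetical.

```python
import math


def num_batches(num_samples: int, batch_size: int) -> int:
    # len() of an unsharded DataLoader with drop_last=False
    return math.ceil(num_samples / batch_size)


def batches_per_process(total_batches: int, num_processes: int) -> int:
    # Assumption: batches are dealt round-robin and shorter shards are padded
    # to match the longest one, so each process sees ceil(total / n) batches.
    return math.ceil(total_batches / num_processes)


before = num_batches(num_samples=10_000, batch_size=8)
after = batches_per_process(before, num_processes=4)
print(before, after)  # 1250 313
```

Multiplying the 1250 figure by the number of epochs overstates each process's step count by roughly a factor of four in this setup.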

As a result, total_steps ends up larger than the true total number of training steps (by roughly a factor of the number of GPUs), so training ends prematurely: the epochs are exhausted long before the step counter reaches total_steps.
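The consequence can be sketched with a hypothetical stand-in for the training loop that stops either when the epochs run out or when the step counter reaches total_steps. `run_training` is illustrative only, not trlX's actual loop, and the numbers assume a 4-GPU run with 1250 batches before prepare() and 313 per process afterwards.

```python
def run_training(epochs: int, batches_per_epoch: int, total_steps: int) -> int:
    """Hypothetical trainer loop: stops when the epochs are exhausted or when
    total_steps is reached, whichever comes first. Returns steps taken."""
    iter_count = 0
    for _ in range(epochs):
        for _ in range(batches_per_epoch):
            iter_count += 1
            if iter_count >= total_steps:
                return iter_count
    return iter_count


epochs = 3
buggy_total = epochs * 1250  # computed from the unprepared dataloader
taken = run_training(epochs, batches_per_epoch=313, total_steps=buggy_total)
print(taken, buggy_total)  # 939 3750
```

Each process stops after 939 steps, about a quarter of the 3750 that total_steps advertises.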

To fix this issue, I recommend modifying the prepare_learning() function to calculate self.total_steps using the self.train_dataloader variable instead, which correctly reflects the number of training steps after accelerator preparation. Here is the suggested modification:

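    def prepare_learning(self):
        # Build the loaders first; at this point their lengths count every batch in the dataset.
        train_dataloader = self.store.create_loader(self.config.train.batch_size)
        eval_dataloader = self.eval_pipeline.create_loader(self.config.train.batch_size)

        # prepare() wraps the loaders so that each process iterates only over its own shard.
        (
            self.model,
            self.opt,
            self.train_dataloader,
            self.eval_dataloader,
        ) = self.accelerator.prepare(self.model, self.opt, train_dataloader, eval_dataloader)

        self.n_updates_per_batch = 1
        # Use the prepared (per-process) loader, not the local train_dataloader,
        # so total_steps matches the steps each process actually takes.
        self.total_steps = self.config.train.epochs * len(self.train_dataloader)
        # Still respect the configured upper bound.
        self.total_steps = min(self.total_steps, self.config.train.total_steps)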
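One way to check the fix without multiple GPUs is to pull the arithmetic into a pure function. `compute_total_steps` is a hypothetical helper, not part of trlX, and a `range` stands in for the prepared dataloader because only its `len()` matters here.

```python
from collections.abc import Sized


def compute_total_steps(epochs: int, train_dataloader: Sized, max_total_steps: int) -> int:
    # len() of the dataloader returned by accelerator.prepare(), i.e. the
    # number of batches this process iterates over per epoch.
    steps = epochs * len(train_dataloader)
    # Cap at the configured limit (config.train.total_steps in the snippet above).
    return min(steps, max_total_steps)


prepared = range(313)  # stand-in for a loader sharded to 313 batches per process
print(compute_total_steps(3, prepared, max_total_steps=10_000))  # 939
print(compute_total_steps(3, prepared, max_total_steps=500))  # 500
```

Keeping the computation separate from prepare_learning() lets it be unit-tested with any object that has a length.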

I hope this helps! Let me know if you have any questions or if there's anything else I can assist you with.

Which trlX version are you using?

newest

Additional system and package information

No response

@jon-tow
Collaborator

jon-tow commented Apr 13, 2023

Thanks for reporting! Fixed with @reciprocated's #432 patch.

jon-tow closed this as completed Apr 13, 2023