Colab TPU training stuck at end of epoch if checkpoint_callback=True #3660

Closed
ktrapeznikov opened this issue Sep 25, 2020 · 10 comments · Fixed by #4309
Labels: accelerator: tpu (Tensor Processing Unit) · bug (Something isn't working) · waiting on author (Waiting on user action, correction, or update)

Comments
@ktrapeznikov

Training gets stuck at the last step of the first epoch when checkpoint_callback=True is enabled. I am passing tpu_cores=8 to the trainer.

If it's False, training is still slow for the first few steps and then speeds up (I assume this is expected, since it takes XLA a few steps to compile the graphs — see the aside after the repro script below).

Here is a link to the colab notebook example.

### Setup Model
import argparse

import torch
from torch.utils.data import DataLoader, TensorDataset

import pytorch_lightning as pl
from transformers import AutoModelForCausalLM


class GPT2Tuned(pl.LightningModule):
    def __init__(self, hparams: argparse.Namespace):
        super().__init__()
        self.hparams = hparams
        self.model = AutoModelForCausalLM.from_pretrained(self.hparams.model_name_or_path)

    def forward(self, **inputs):
        return self.model(**inputs)

    def _step(self, batch):
        # Causal language modeling: feed the inputs as their own labels.
        inputs = {"input_ids": batch[0], "labels": batch[0]}
        outputs = self(**inputs)
        loss = outputs[0]
        return dict(loss=loss)

    def training_step(self, batch, batch_idx):
        out_dict = self._step(batch)
        return dict(loss=out_dict["loss"], log=out_dict)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)

    def train_dataloader(self):
        # Fake data: 10000 random token sequences of length 128.
        dataset = TensorDataset(torch.randint(10000, (10000, 128)).long())
        return DataLoader(
            dataset,
            batch_size=self.hparams.batch_size,
            shuffle=True,
        )

    @staticmethod
    def add_model_specific_args(parser):
        parser.add_argument("--model_name_or_path", type=str)
        parser.add_argument("--batch_size", default=4, type=int)
        parser.add_argument("--learning_rate", default=5e-5, type=float)
        return parser
#### Training
parser = argparse.ArgumentParser()
parser = GPT2Tuned.add_model_specific_args(parser)
parser = pl.Trainer.add_argparse_args(parser)
args = parser.parse_args(["--gpus", "0"])

# Override the defaults: run on 8 TPU cores with 16-bit precision.
args.__dict__.update(dict(model_name_or_path="gpt2-medium",
                          batch_size=3,
                          precision=16,
                          tpu_cores=8))

model = GPT2Tuned(args)
trainer = pl.Trainer.from_argparse_args(args, checkpoint_callback=True)
trainer.fit(model)
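
(Aside on the slow first steps: XLA compiles a separate graph for every distinct tensor shape, so the smaller final batch of each epoch forces an extra compilation. A minimal tweak to the fake-data loader above, a sketch rather than a confirmed fix for this issue, is to drop the incomplete last batch:)

    def train_dataloader(self):
        dataset = TensorDataset(torch.randint(10000, (10000, 128)).long())
        # drop_last=True keeps every batch the same shape, so XLA reuses
        # its cached compiled graph instead of recompiling at epoch end.
        return DataLoader(
            dataset,
            batch_size=self.hparams.batch_size,
            shuffle=True,
            drop_last=True,
        )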
@ktrapeznikov changed the title from “Colab TPU training stuck at end if epoch if checkpoint_callback=True” to “Colab TPU training stuck at end of epoch if checkpoint_callback=True” on Sep 25, 2020
@github-actions
Contributor

Hi! Thanks for your contribution, great first issue!

@awaelchli
Contributor

Thanks for the reproduction script.
Could be related to #2700.
There is a PR for fixing checkpointing on TPU: #2726 (need to check if this applies).

@ktrapeznikov
Author

Thanks. Seems related. Is there another mechanism to save checkpoints as a temporary fix?
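
(One possible stopgap, a sketch not taken from this thread and assuming torch_xla is installed, is to bypass the built-in checkpointing entirely and save the raw weights from a custom callback with xm.save, which moves tensors to CPU and writes only on the master ordinal:)

import pytorch_lightning as pl
import torch_xla.core.xla_model as xm

class ManualTPUCheckpoint(pl.Callback):
    """Hypothetical workaround: save plain weights ourselves instead of
    relying on the ModelCheckpoint path that hangs here."""

    def __init__(self, filepath="weights.pt"):
        self.filepath = filepath

    def on_epoch_end(self, trainer, pl_module):
        # Runs in every TPU process; xm.save writes only from the master
        # ordinal, so all 8 cores stay in step.
        xm.save(pl_module.state_dict(), self.filepath)

trainer = pl.Trainer(tpu_cores=8, checkpoint_callback=False,
                     callbacks=[ManualTPUCheckpoint()])

(This saves only the model weights, not the optimizer or trainer state, so resuming mid-run would still need the real fix.)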

@awaelchli added the accelerator: tpu (Tensor Processing Unit) and bug (Something isn't working) labels on Sep 28, 2020
@Borda
Member

Borda commented Oct 1, 2020

@ktrapeznikov mind checking it now? Google Colab had an internal issue with TPUs in the past few days...

@Borda added the waiting on author (Waiting on user action, correction, or update) label on Oct 1, 2020
@ktrapeznikov
Author

Last time I tried... same issue

@evanatyourservice

This would be a nice feature to have since I would rarely do inference on TPU, just training.

@edenlightning removed the waiting on author (Waiting on user action, correction, or update) label on Oct 19, 2020
@edenlightning added this to the 1.0.3 milestone on Oct 19, 2020
@edenlightning
Contributor

@lezwon maybe you can help revive the tpu + checkpoint fix?

@lezwon
Contributor

lezwon commented Oct 22, 2020

Been working on it. Finding it a bit tricky. I'll raise a PR with the current work I've done.

@sarmientoj24

Hi, I am getting stuck at 0% on the first epoch when using TPU with pytorch-lightning. Any possible reasons why?

@lezwon
Contributor

lezwon commented Nov 1, 2020

@sarmientoj24 mind sharing the notebook?

@lezwon mentioned this issue on Nov 1, 2020
@edenlightning modified the milestones: 1.0.x → 1.0.7 on Nov 10, 2020
@Borda modified the milestones: 1.0.7 → 1.0.x on Nov 11, 2020
@edenlightning removed this from the 1.0.x milestone on Nov 13, 2020
@edenlightning added the waiting on author (Waiting on user action, correction, or update) label on Nov 17, 2020