Merge pull request #43 from krasserm/wip-pytorch-2
Upgrade to PyTorch 2.0 and PyTorch Lightning 2.0
krasserm authored Apr 6, 2023

2 parents 9f49d0b + 15f0e4d commit 737a766
Showing 5 changed files with 1,051 additions and 836 deletions.
4 changes: 2 additions & 2 deletions environment.yml
@@ -5,6 +5,6 @@ channels:
 dependencies:
   - python=3.9
   - pytorch-cuda=11.7
-  - pytorch=1.13
-  - torchvision=0.14
+  - pytorch=2.0
+  - torchvision=0.15
   - pip>=22
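After recreating the conda environment, a quick sanity check (a hypothetical snippet, not part of the commit) confirms that the bumped pins resolved:

import torch
import torchvision

print(torch.__version__)         # expected to start with "2.0"
print(torchvision.__version__)   # expected to start with "0.15"
print(torch.version.cuda)        # expected "11.7", matching the pytorch-cuda pin
print(torch.cuda.is_available())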
1 change: 0 additions & 1 deletion examples/training/clm/train_fsdp.sh
@@ -30,7 +30,6 @@ python -m perceiver.scripts.text.clm_fsdp fit \
   --trainer.precision=bf16 \
   --trainer.max_steps=50000 \
   --trainer.accumulate_grad_batches=1 \
-  --trainer.track_grad_norm=2 \
   --trainer.check_val_every_n_epoch=null \
   --trainer.val_check_interval=500 \
   --trainer.limit_val_batches=20 \
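The --trainer.track_grad_norm flag goes away here because PyTorch Lightning 2.0 removed it from the Trainer; gradient-norm logging moves into the LightningModule, as the clm_fsdp.py hunk below shows. A minimal sketch of that migration pattern (ToyModule is illustrative, not from the commit):

import torch
import pytorch_lightning as pl
from pytorch_lightning.utilities import grad_norm


class ToyModule(pl.LightningModule):
    # Illustrative module; only the hook matters for the migration.
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def on_before_optimizer_step(self, optimizer):
        # PL 2.0 removed Trainer(track_grad_norm=...) and dropped the
        # optimizer_idx argument from this hook; log the 2-norms manually.
        self.log_dict(grad_norm(self, norm_type=2))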
8 changes: 5 additions & 3 deletions perceiver/scripts/text/clm_fsdp.py
@@ -3,7 +3,8 @@

 import torch
 from pytorch_lightning.cli import LightningArgumentParser, LRSchedulerCallable, OptimizerCallable
-from pytorch_lightning.strategies import DDPFullyShardedNativeStrategy, StrategyRegistry
+from pytorch_lightning.strategies import FSDPStrategy, StrategyRegistry
+from pytorch_lightning.utilities import grad_norm
 from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

 from perceiver.model.core import CrossAttentionLayer, SelfAttentionLayer
@@ -27,7 +28,7 @@

 StrategyRegistry.register(
     name="fsdp_perceiver_ar",
-    strategy=DDPFullyShardedNativeStrategy,
+    strategy=FSDPStrategy,
     description="FSDP strategy optimized for Perceiver AR models",
     activation_checkpointing=[CrossAttentionLayer, SelfAttentionLayer],
     auto_wrap_policy=policy,
@@ -60,9 +61,10 @@ def configure_optimizers(self):
             "lr_scheduler": {"scheduler": scheduler, "interval": "step", "frequency": 1},
         }

-    def on_before_optimizer_step(self, optimizer, optimizer_idx):
+    def on_before_optimizer_step(self, optimizer):
         if self.hparams.max_grad_norm is not None:
             self.trainer.model.clip_grad_norm_(self.hparams.max_grad_norm)
+        self.log_dict(grad_norm(self, norm_type=2))


 class CausalLanguageModelCLI(CLI):
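Two PyTorch Lightning 2.0 API changes drive this file: DDPFullyShardedNativeStrategy was renamed to FSDPStrategy, and on_before_optimizer_step lost its optimizer_idx argument. The policy passed to StrategyRegistry.register is defined earlier in the script and not shown in the hunk; a plausible sketch of such an auto-wrap policy (an assumption, not code copied from the commit):

import functools

from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

from perceiver.model.core import CrossAttentionLayer, SelfAttentionLayer

# Treat each attention layer as an FSDP wrapping boundary so its parameters
# are sharded as one unit; the same classes are listed for activation
# checkpointing in the strategy registration above.
policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={CrossAttentionLayer, SelfAttentionLayer},
)

Once registered, the strategy would be selected by name, e.g. --trainer.strategy=fsdp_perceiver_ar.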
1,866 changes: 1,040 additions & 826 deletions poetry.lock

Large diffs are not rendered by default.

8 changes: 4 additions & 4 deletions pyproject.toml
@@ -32,8 +32,8 @@ include = ["docs"]

 [tool.poetry.dependencies]
 python = "^3.8,<3.11"
-pytorch-lightning = "^1.7"
-torch = "^1.13"
+pytorch-lightning = "^2.0"
+torch = "^2.0"
 fairscale = "^0.4"
 torchmetrics = "^0.9"
 torch-optimizer = "^0.3"
@@ -43,9 +43,9 @@ cchardet = "^2.1"
 datasets = {version = "^2.4", optional = true}
 tokenizers = {version = "^0.12", optional = true}
 transformers = {version = "^4.21", optional = true}
-torchvision = {version = "^0.14", optional = true}
+torchvision = {version = "^0.15", optional = true}
 opencv-python = {version = "^4.6.0.66", optional = true}
-jsonargparse = {extras = ["signatures"], version = "^4.12"}
+jsonargparse = {extras = ["signatures"], version = "^4.18"}
 fsspec = {extras = ["s3"], version = "*"}

 [tool.poetry.group.dev.dependencies]
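A quick, hypothetical way to confirm that poetry install resolved to the bumped versions (not part of the commit):

# importlib.metadata ships with Python 3.8+, matching the project's floor.
from importlib.metadata import version

for pkg in ("torch", "pytorch-lightning", "torchvision", "jsonargparse"):
    print(pkg, version(pkg))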
