Skip to content

Commit

Permalink
IPU Integration 5/5 (Lightning-AI#7867)
Browse files Browse the repository at this point in the history
* Initial changes

* Add broken example for now

* Fix reference

* Fix format

* Code runs

* Fixes

* Clear up files

* Add tests, helpers, fixes

* Small cleanups

* Refactors based on review

* Swap to special tests

* Add special tests

* Add source

* Cleanups

* Add logic to attach/detach model from devices

* Fixes for tests

* Fixes for tests

* Move earlier

* Cleanups

* Add check for nvcc

* Add tests, cleanups

* Fix errors

* fix

* Try condition

* Add missing annotation

* Clearer

* Clearer message

* Fix variable

* Cleanups

* Add comment

* CHANGELOG.md

* Add simple selection test

* Remove special=True to see what happens

* Fix test

* Update tests/accelerators/test_ipu.py

Co-authored-by: Kaushik B <[email protected]>

* Convert ipu_cores -> ipus

* Add typing, fail earlier

* simplify precision

* Add test, add helper

* fix accum

* Update pytorch_lightning/plugins/training_type/ipu.py

Co-authored-by: thomas chaton <[email protected]>

* Use stages

* Make sure warning message returned

* thorw error

* Add more tests, use fs

* add comment

* Clean

* Address feedback, add IPU tests

* Fixes

* Fix signature

* Add types

* Remove autoround

* Add docstring

* ipu_cores -> ipus

* Add test, remove unnecessary precision set

* Add optimizer test

* Add precision back with test

* Address code review

* Change to probs

* Move some of the asserts earlier

Co-authored-by: Kaushik B <[email protected]>
Co-authored-by: thomas chaton <[email protected]>
  • Loading branch information
3 people authored and Daniel Dale committed Jun 11, 2021
1 parent 3701825 commit 4c0baa3
Show file tree
Hide file tree
Showing 15 changed files with 1,150 additions and 5 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added trainer stage hooks for Training Plugins and Accelerators ([#7864](https://github.com/PyTorchLightning/pytorch-lightning/pull/7864))


- Added IPU Accelerator ([#7867](https://github.com/PyTorchLightning/pytorch-lightning/pull/7867))


- Added a warning if `Trainer(log_every_n_steps)` is a value too high for the training dataloader ([#7734](https://github.com/PyTorchLightning/pytorch-lightning/pull/7734))


Expand Down
Empty file.
89 changes: 89 additions & 0 deletions pl_examples/ipu_examples/mnist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from torch.nn import functional as F

import pytorch_lightning as pl
from pl_examples.basic_examples.mnist_datamodule import MNISTDataModule


class LitClassifier(pl.LightningModule):

def __init__(
self,
hidden_dim: int = 128,
learning_rate: float = 0.0001,
):
super().__init__()
self.save_hyperparameters()

self.l1 = torch.nn.Linear(28 * 28, self.hparams.hidden_dim)
self.l2 = torch.nn.Linear(self.hparams.hidden_dim, 10)

def forward(self, x):
x = x.view(x.size(0), -1)
x = torch.relu(self.l1(x))
x = torch.relu(self.l2(x))
return x

def training_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
loss = F.cross_entropy(y_hat, y)
return loss

def validation_step(self, batch, batch_idx):
x, y = batch
probs = self(x)
# we currently return the accuracy as the validation_step/test_step is run on the IPU devices.
# Outputs from the step functions are sent to the host device, where we calculate the metrics in
# validation_epoch_end and test_epoch_end for the test_step.
acc = self.accuracy(probs, y)
return acc

def test_step(self, batch, batch_idx):
x, y = batch
logits = self(x)
acc = self.accuracy(logits, y)
return acc

def accuracy(self, logits, y):
# currently IPU poptorch doesn't implicit convert bools to tensor
# hence we use an explicit calculation for accuracy here. Once fixed in poptorch
# we can use the accuracy metric.
acc = torch.sum(torch.eq(torch.argmax(logits, -1), y).to(torch.float32)) / len(y)
return acc

def validation_epoch_end(self, outputs) -> None:
# since the training step/validation step and test step are run on the IPU device
# we must log the average loss outside the step functions.
self.log('val_acc', torch.stack(outputs).mean(), prog_bar=True)

def test_epoch_end(self, outputs) -> None:
self.log('test_acc', torch.stack(outputs).mean())

def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)


if __name__ == '__main__':
dm = MNISTDataModule(batch_size=32)

model = LitClassifier()

trainer = pl.Trainer(max_epochs=2, ipus=8)

trainer.fit(model, datamodule=dm)
trainer.test(model, datamodule=dm)
1 change: 1 addition & 0 deletions pytorch_lightning/accelerators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,5 @@
from pytorch_lightning.accelerators.accelerator import Accelerator # noqa F401
from pytorch_lightning.accelerators.cpu import CPUAccelerator # noqa F401
from pytorch_lightning.accelerators.gpu import GPUAccelerator # noqa F401
from pytorch_lightning.accelerators.ipu import IPUAccelerator # noqa F401
from pytorch_lightning.accelerators.tpu import TPUAccelerator # noqa F401
35 changes: 35 additions & 0 deletions pytorch_lightning/accelerators/ipu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import Callable
from typing import Any

from torch.optim import Optimizer

import pytorch_lightning as pl
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.utilities.exceptions import MisconfigurationException


class IPUAccelerator(Accelerator):
""" Accelerator for IPUs. """

def setup_optimizers(self, trainer: 'pl.Trainer') -> None:
super().setup_optimizers(trainer)

if len(self.optimizers) > 1:
raise MisconfigurationException("IPUs currently only support one optimizer.")

def optimizer_step(self, optimizer: Optimizer, opt_idx: int, lambda_closure: Callable, **kwargs: Any) -> None:
# Optimizer step is handled by the IPU accelerator.
lambda_closure()
4 changes: 4 additions & 0 deletions pytorch_lightning/plugins/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from pytorch_lightning.plugins.precision.fully_sharded_native_amp import ( # noqa: F401
FullyShardedNativeMixedPrecisionPlugin,
)
from pytorch_lightning.plugins.precision.ipu_precision import IPUPrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin # noqa: F401
Expand All @@ -20,6 +21,7 @@
from pytorch_lightning.plugins.training_type.dp import DataParallelPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.fully_sharded import DDPFullyShardedPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.horovod import HorovodPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.ipu import IPUPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.rpc import RPCPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.rpc_sequential import RPCSequentialPlugin # noqa: F401
Expand All @@ -41,6 +43,8 @@
"DeepSpeedPrecisionPlugin",
"DoublePrecisionPlugin",
"HorovodPlugin",
"IPUPlugin",
"IPUPrecisionPlugin",
"NativeMixedPrecisionPlugin",
"PrecisionPlugin",
"ShardedNativeMixedPrecisionPlugin",
Expand Down
60 changes: 60 additions & 0 deletions pytorch_lightning/plugins/precision/ipu_precision.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional, Union

from torch import Tensor
from torch.nn import Module
from torch.optim import Optimizer

import pytorch_lightning as pl
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
from pytorch_lightning.utilities import GradClipAlgorithmType
from pytorch_lightning.utilities.exceptions import MisconfigurationException


class IPUPrecisionPlugin(PrecisionPlugin):

def __init__(self, precision: int) -> None:
super().__init__()
self.precision = precision

def backward(
self,
model: 'pl.LightningModule',
closure_loss: Tensor,
optimizer: Optimizer,
opt_idx: int,
should_accumulate: bool,
*args: Any,
**kwargs: Any,
) -> Tensor:
# IPU internally manages bwd step.
return closure_loss

def clip_gradients(
self,
optimizer: Optimizer,
clip_val: Union[int, float],
gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM,
model: Optional[Module] = None
) -> None:
"""Clips the gradients"""
if clip_val is None:
return

clip_val = float(clip_val)
if clip_val <= 0:
return

raise MisconfigurationException("IPUs currently do not support clipping gradients.")
Loading

0 comments on commit 4c0baa3

Please sign in to comment.