Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

IPU Integration 5/5 #7867

Merged
merged 69 commits into from
Jun 11, 2021
Merged
Show file tree
Hide file tree
Changes from 67 commits
Commits
Show all changes
69 commits
Select commit Hold shift + click to select a range
f75f445
Initial changes
Feb 22, 2021
a4a60c2
Merge branch 'master' into wip/acc
Mar 24, 2021
dc9744b
Add broken example for now
Mar 24, 2021
931bb74
Fix reference
Apr 7, 2021
9b18baf
Merge branch 'master' into wip/acc
May 11, 2021
c617f02
Fix format
May 11, 2021
522a81f
Code runs
May 11, 2021
0c00360
Fixes
May 26, 2021
30c1370
Merge branch 'master' into wip/acc
May 26, 2021
adbdb2a
Clear up files
May 26, 2021
3e733af
Add tests, helpers, fixes
May 27, 2021
a51f23e
Small cleanups
May 27, 2021
be7de87
Refactors based on review
Jun 1, 2021
83c8a79
Swap to special tests
Jun 1, 2021
a6018e5
Add special tests
Jun 1, 2021
0e71bbe
Add source
Jun 1, 2021
6e38bd1
Cleanups
Jun 1, 2021
526383f
Add logic to attach/detach model from devices
Jun 2, 2021
e18039c
Fixes for tests
Jun 2, 2021
2e43fee
Fixes for tests
Jun 2, 2021
53d31a0
Move earlier
Jun 2, 2021
6241432
Cleanups
Jun 2, 2021
d249a13
Add check for nvcc
Jun 2, 2021
d08cf39
Add tests, cleanups
Jun 2, 2021
7469744
Fix errors
Jun 3, 2021
f474c5b
fix
Jun 3, 2021
e178d5f
Try condition
Jun 3, 2021
c704920
Add missing annotation
Jun 3, 2021
c54a216
Clearer
Jun 3, 2021
2ea1766
Clearer message
Jun 3, 2021
751f0ea
Fix variable
Jun 3, 2021
87e4c8a
Merge branch 'master' into wip/acc
Jun 7, 2021
61d2014
Cleanups
Jun 7, 2021
d76f491
Merge branch 'master' into wip/acc
Jun 7, 2021
62860ff
Add comment
Jun 7, 2021
b5a5032
CHANGELOG.md
Jun 7, 2021
72ed367
Add simple selection test
Jun 7, 2021
88fba4a
Merge branch 'master' into wip/acc
Jun 7, 2021
3fb031d
Remove special=True to see what happens
Jun 7, 2021
515d491
Fix test
Jun 7, 2021
ed16808
Update tests/accelerators/test_ipu.py
Jun 7, 2021
7f50295
Convert ipu_cores -> ipus
Jun 7, 2021
c53cf88
Add typing, fail earlier
Jun 7, 2021
a6dbd8a
simplify precision
Jun 7, 2021
953454b
Add test, add helper
Jun 8, 2021
24829bf
fix accum
Jun 8, 2021
d7d38c5
Update pytorch_lightning/plugins/training_type/ipu.py
Jun 8, 2021
c333e27
Use stages
Jun 8, 2021
9d3741a
Make sure warning message returned
Jun 8, 2021
fd1899a
thorw error
Jun 8, 2021
0727954
Add more tests, use fs
Jun 8, 2021
ce182f7
add comment
Jun 8, 2021
7e81bcd
Clean
Jun 8, 2021
d1788d1
Address feedback, add IPU tests
Jun 9, 2021
08e5338
Fixes
Jun 9, 2021
45dc6a6
Fix signature
Jun 9, 2021
de040c6
Add types
Jun 9, 2021
42d7ab0
Remove autoround
Jun 9, 2021
8ab62c4
Merge branch 'master' into wip/acc
Jun 9, 2021
36f3672
Add docstring
Jun 9, 2021
5f89714
Merge branch 'master' into wip/acc
Jun 10, 2021
d0f98f3
Merge branch 'master' into wip/acc
Jun 10, 2021
f9d61c5
ipu_cores -> ipus
Jun 10, 2021
cf48ff8
Add test, remove unnecessary precision set
Jun 11, 2021
02a75b5
Add optimizer test
Jun 11, 2021
d18fc55
Add precision back with test
Jun 11, 2021
043884a
Address code review
Jun 11, 2021
b249391
Change to probs
Jun 11, 2021
b0dd206
Move some of the asserts earlier
Jun 11, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added trainer stage hooks for Training Plugins and Accelerators ([#7864](https://github.com/PyTorchLightning/pytorch-lightning/pull/7864))


- Added IPU Accelerator ([#7867](https://github.com/PyTorchLightning/pytorch-lightning/pull/7867))


- Added a warning if `Trainer(log_every_n_steps)` is a value too high for the training dataloader ([#7734](https://github.com/PyTorchLightning/pytorch-lightning/pull/7734))


Expand Down
Empty file.
89 changes: 89 additions & 0 deletions pl_examples/ipu_examples/mnist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from torch.nn import functional as F

import pytorch_lightning as pl
from pl_examples.basic_examples.mnist_datamodule import MNISTDataModule


class LitClassifier(pl.LightningModule):

def __init__(
self,
hidden_dim: int = 128,
learning_rate: float = 0.0001,
):
super().__init__()
self.save_hyperparameters()

self.l1 = torch.nn.Linear(28 * 28, self.hparams.hidden_dim)
self.l2 = torch.nn.Linear(self.hparams.hidden_dim, 10)

def forward(self, x):
x = x.view(x.size(0), -1)
x = torch.relu(self.l1(x))
x = torch.relu(self.l2(x))
return x

def training_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
loss = F.cross_entropy(y_hat, y)
return loss

def validation_step(self, batch, batch_idx):
x, y = batch
logits = self(x)
SeanNaren marked this conversation as resolved.
Show resolved Hide resolved
# we currently return the accuracy as the validation_step/test_step is run on the IPU devices.
# Outputs from the step functions are sent to the host device, where we calculate the metrics in
# validation_epoch_end and test_epoch_end for the test_step.
acc = self.accuracy(logits, y)
return acc
SeanNaren marked this conversation as resolved.
Show resolved Hide resolved

def test_step(self, batch, batch_idx):
x, y = batch
logits = self(x)
acc = self.accuracy(logits, y)
return acc

def accuracy(self, logits, y):
# currently IPU poptorch doesn't implicit convert bools to tensor
# hence we use an explicit calculation for accuracy here. Once fixed in poptorch
# we can use the accuracy metric.
acc = torch.sum(torch.eq(torch.argmax(logits, -1), y).to(torch.float32)) / len(y)
return acc

def validation_epoch_end(self, outputs) -> None:
# since the training step/validation step and test step are run on the IPU device
# we must log the average loss outside the step functions.
self.log('val_acc', torch.stack(outputs).mean(), prog_bar=True)

def test_epoch_end(self, outputs) -> None:
self.log('test_acc', torch.stack(outputs).mean())

def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)


if __name__ == '__main__':
dm = MNISTDataModule(batch_size=32)

model = LitClassifier()

trainer = pl.Trainer(max_epochs=2, ipus=8)

trainer.fit(model, datamodule=dm)
trainer.test(model, datamodule=dm)
1 change: 1 addition & 0 deletions pytorch_lightning/accelerators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,5 @@
from pytorch_lightning.accelerators.accelerator import Accelerator # noqa F401
from pytorch_lightning.accelerators.cpu import CPUAccelerator # noqa F401
from pytorch_lightning.accelerators.gpu import GPUAccelerator # noqa F401
from pytorch_lightning.accelerators.ipu import IPUAccelerator # noqa F401
from pytorch_lightning.accelerators.tpu import TPUAccelerator # noqa F401
35 changes: 35 additions & 0 deletions pytorch_lightning/accelerators/ipu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import Callable
from typing import Any

from torch.optim import Optimizer

import pytorch_lightning as pl
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.utilities.exceptions import MisconfigurationException


class IPUAccelerator(Accelerator):
""" Accelerator for IPUs. """

def setup_optimizers(self, trainer: 'pl.Trainer') -> None:
super().setup_optimizers(trainer)

if len(self.optimizers) > 1:
raise MisconfigurationException("IPUs currently only support one optimizer.")

def optimizer_step(self, optimizer: Optimizer, opt_idx: int, lambda_closure: Callable, **kwargs: Any) -> None:
# Optimizer step is handled by the IPU accelerator.
lambda_closure()
4 changes: 4 additions & 0 deletions pytorch_lightning/plugins/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from pytorch_lightning.plugins.precision.fully_sharded_native_amp import ( # noqa: F401
FullyShardedNativeMixedPrecisionPlugin,
)
from pytorch_lightning.plugins.precision.ipu_precision import IPUPrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin # noqa: F401
Expand All @@ -20,6 +21,7 @@
from pytorch_lightning.plugins.training_type.dp import DataParallelPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.fully_sharded import DDPFullyShardedPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.horovod import HorovodPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.ipu import IPUPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.rpc import RPCPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.rpc_sequential import RPCSequentialPlugin # noqa: F401
Expand All @@ -41,6 +43,8 @@
"DeepSpeedPrecisionPlugin",
"DoublePrecisionPlugin",
"HorovodPlugin",
"IPUPlugin",
"IPUPrecisionPlugin",
"NativeMixedPrecisionPlugin",
"PrecisionPlugin",
"ShardedNativeMixedPrecisionPlugin",
Expand Down
60 changes: 60 additions & 0 deletions pytorch_lightning/plugins/precision/ipu_precision.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional, Union

from torch import Tensor
from torch.nn import Module
from torch.optim import Optimizer

import pytorch_lightning as pl
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
from pytorch_lightning.utilities import GradClipAlgorithmType
from pytorch_lightning.utilities.exceptions import MisconfigurationException


class IPUPrecisionPlugin(PrecisionPlugin):

def __init__(self, precision: int) -> None:
super().__init__()
self.precision = precision

def backward(
self,
model: 'pl.LightningModule',
closure_loss: Tensor,
optimizer: Optimizer,
opt_idx: int,
should_accumulate: bool,
*args: Any,
**kwargs: Any,
) -> Tensor:
# IPU internally manages bwd step.
return closure_loss

def clip_gradients(
self,
optimizer: Optimizer,
clip_val: Union[int, float],
gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM,
model: Optional[Module] = None
) -> None:
"""Clips the gradients"""
if clip_val is None:
return

clip_val = float(clip_val)
if clip_val <= 0:
return

raise MisconfigurationException("IPUs currently do not support clipping gradients.")
Loading