Add throughput utilities to Fabric and the Trainer #18848

Merged (33 commits) on Oct 30, 2023

Conversation

@carmocca (Contributor) commented Oct 24, 2023

What does this PR do?

Ports https://github.com/Lightning-AI/lit-gpt/blob/main/lit_gpt/speed_monitor.py.

The API has changed: on Fabric it now follows the torchmetrics style of update and compute, while the Trainer keeps a regular Callback.

Edge cases received careful consideration to minimize user error; for instance, a _MonotonicWindow class replaces a regular deque (a sketch of the idea follows below).
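
To illustrate the idea, here is a minimal sketch of a monotonic window, assuming it simply rejects out-of-order values; the actual _MonotonicWindow in this PR may differ in detail:

from collections import deque


class MonotonicWindow:
    """Bounded window that only accepts monotonically increasing values (illustrative sketch)."""

    def __init__(self, maxlen):
        self._values = deque(maxlen=maxlen)

    def append(self, value):
        if self._values and value <= self._values[-1]:
            # A non-increasing value (e.g. a timer that was reset between updates) would
            # corrupt any rate computed over the window, so fail loudly instead.
            raise ValueError(f"Expected a value greater than {self._values[-1]}, got {value}")
        self._values.append(value)

    def __len__(self):
        return len(self._values)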

(Fabric) Throughput example:

from time import perf_counter
import torch
from torch.utils.flop_counter import bmm_flop
from lightning.fabric.utilities import Throughput
from lightning.fabric.utilities.throughput import get_available_flops

torch.inference_mode().__enter__()  # enable inference mode for the whole script (no autograd tracking)
device = torch.device("cuda")
B, N = 1024, 1024
x = torch.randn(B, N, N, device=device)
available_flops = get_available_flops(device, torch.float32)
flops_per_batch = bmm_flop(x.shape, x.shape)
print(f"TFLOPs: {flops_per_batch / 1e12} out of {available_flops / 1e12} ({flops_per_batch / available_flops:.3%})")

throughput = Throughput(available_flops=available_flops, window_size=10)
t0 = perf_counter()
for i in range(1, 101):
    # simulate work
    y = x @ x

    torch.cuda.synchronize()  # required so the elapsed time reflects the completed GPU work
    throughput.update(time=perf_counter() - t0, batches=i, samples=i * B, flops=flops_per_batch)
    if i % 10 == 0:
        print(i, throughput.compute())

print(f"{torch.cuda.max_memory_allocated() / 1e9} GB")

(Fabric) ThroughputMonitor example:

from time import time
import torch
from lightning import Fabric
from lightning.fabric.utilities.throughput import measure_flops, ThroughputMonitor

B, N = 1024, 1024


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("x", torch.randn(B, N, N))

    def forward(self):
        # simulate work
        return self.x @ self.x


torch.inference_mode().__enter__()  # enable inference mode for the whole script (no autograd tracking)

fabric = Fabric(accelerator="cuda", devices=1, precision="32-true")

with torch.device("meta"):
    meta_model = Model()
    flops_per_batch = measure_flops(meta_model, meta_model.forward)
throughput = ThroughputMonitor(fabric, window_size=10)
print(
    f"TFLOPs: {flops_per_batch / 1e12} out of"
    f" {throughput.available_flops / 1e12} ({flops_per_batch / throughput.available_flops:.3%})"
)

model = Model()
model = fabric.setup(model)

t0 = time()
for i in range(1, 101):
    y = model()

    torch.cuda.synchronize()  # required so the elapsed time reflects the completed GPU work
    throughput.update(time=time() - t0, batches=i, samples=i * B, flops=flops_per_batch)
    if i % 10 == 0:
        print(i, throughput.compute())

print(f"{torch.cuda.max_memory_allocated() / 1e9} GB")

(Trainer) ThroughputMonitor example:

from itertools import count
from unittest.mock import Mock

import torch

from lightning import LightningModule, Trainer
from lightning.fabric.utilities.throughput import measure_flops
from lightning.pytorch.callbacks import ThroughputMonitor


class PrintingLogger(Mock):
    save_dir = "."

    def log_metrics(self, metrics, step):
        print(step, metrics)


B, N = 1024, 1024


class Model(LightningModule):
    def __init__(self):
        super().__init__()
        self.register_buffer("x", torch.randn(B, N, N))

    def setup(self, stage: str) -> None:
        with torch.device("meta"):
            model = Model()
            self.flops_per_batch = measure_flops(model, model.forward)
        print(
            f"TFLOPs: {self.flops_per_batch / 1e12} out of"
            f" {throughput.available_flops / 1e12} ({self.flops_per_batch / throughput.available_flops:.3%})"
        )

    def forward(self, _=None):
        # simulate work
        return self.x @ self.x


throughput = ThroughputMonitor(batch_size_fn=lambda _: B, window_size=10)
trainer = Trainer(
    accelerator="cuda",
    devices=1,
    precision="32-true",
    limit_predict_batches=100,
    log_every_n_steps=10,
    callbacks=throughput,
    enable_model_summary=False,
    enable_progress_bar=False,
    logger=PrintingLogger(),
)
model = Model()
trainer.predict(model, count(), return_predictions=False)

print(f"{torch.cuda.max_memory_allocated() / 1e9} GB")

📚 Documentation preview 📚: https://pytorch-lightning--18848.org.readthedocs.build/en/18848/

cc @Borda @awaelchli @carmocca @justusschock

@carmocca added the feature, callback, fabric, and pl labels on Oct 24, 2023
@carmocca added this to the 2.2 milestone on Oct 24, 2023
@carmocca self-assigned this on Oct 24, 2023
@github-actions bot added the docs label on Oct 24, 2023
@carmocca changed the title from "Add SpeedMonitor and measure_flops" to "Add ThroughputMonitor and measure_flops" on Oct 24, 2023
@carmocca marked this pull request as ready for review on October 24, 2023 12:20
@github-actions bot commented Oct 24, 2023

⚡ Required checks status: All passing 🟢

Groups summary

🟢 pytorch_lightning: Tests workflow
Check ID Status
pl-cpu (macOS-11, lightning, 3.8, 1.12, oldest) success
pl-cpu (macOS-11, lightning, 3.9, 1.12) success
pl-cpu (macOS-11, lightning, 3.10, 1.13) success
pl-cpu (macOS-11, lightning, 3.10, 2.0) success
pl-cpu (macOS-11, lightning, 3.10, 2.1) success
pl-cpu (ubuntu-20.04, lightning, 3.8, 1.12, oldest) success
pl-cpu (ubuntu-20.04, lightning, 3.9, 1.12) success
pl-cpu (ubuntu-20.04, lightning, 3.10, 1.13) success
pl-cpu (ubuntu-20.04, lightning, 3.10, 2.0) success
pl-cpu (ubuntu-20.04, lightning, 3.10, 2.1) success
pl-cpu (windows-2022, lightning, 3.8, 1.12, oldest) success
pl-cpu (windows-2022, lightning, 3.9, 1.12) success
pl-cpu (windows-2022, lightning, 3.10, 1.13) success
pl-cpu (windows-2022, lightning, 3.10, 2.0) success
pl-cpu (windows-2022, lightning, 3.10, 2.1) success
pl-cpu (macOS-11, pytorch, 3.8, 1.13) success
pl-cpu (ubuntu-20.04, pytorch, 3.8, 1.13) success
pl-cpu (windows-2022, pytorch, 3.8, 1.13) success
pl-cpu (macOS-12, pytorch, 3.11, 2.0) success
pl-cpu (macOS-12, pytorch, 3.11, 2.1) success
pl-cpu (ubuntu-22.04, pytorch, 3.11, 2.0) success
pl-cpu (ubuntu-22.04, pytorch, 3.11, 2.1) success
pl-cpu (windows-2022, pytorch, 3.11, 2.0) success
pl-cpu (windows-2022, pytorch, 3.11, 2.1) success

These checks are required after the changes to src/lightning/fabric/accelerators/cuda.py, src/lightning/fabric/fabric.py, src/lightning/fabric/utilities/__init__.py, src/lightning/fabric/utilities/imports.py, src/lightning/fabric/utilities/rank_zero.py, src/lightning/fabric/utilities/throughput.py, src/lightning/pytorch/callbacks/__init__.py, src/lightning/pytorch/callbacks/throughput_monitor.py, src/lightning/pytorch/core/saving.py, src/lightning/pytorch/trainer/connectors/checkpoint_connector.py, src/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py, src/lightning/pytorch/trainer/connectors/signal_connector.py, src/lightning/pytorch/utilities/__init__.py, src/lightning/pytorch/utilities/imports.py, tests/tests_pytorch/callbacks/test_throughput_monitor.py, tests/tests_pytorch/core/test_datamodules.py, tests/tests_pytorch/models/test_hparams.py, tests/tests_pytorch/utilities/test_imports.py.

🟢 pytorch_lightning: Azure GPU
Check ID Status
pytorch-lightning (GPUs) (testing Lightning | latest) success
pytorch-lightning (GPUs) (testing PyTorch | latest) success

These checks are required after the changes to src/lightning/pytorch/callbacks/__init__.py, src/lightning/pytorch/callbacks/throughput_monitor.py, src/lightning/pytorch/core/saving.py, src/lightning/pytorch/trainer/connectors/checkpoint_connector.py, src/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py, src/lightning/pytorch/trainer/connectors/signal_connector.py, src/lightning/pytorch/utilities/__init__.py, src/lightning/pytorch/utilities/imports.py, tests/tests_pytorch/callbacks/test_throughput_monitor.py, tests/tests_pytorch/core/test_datamodules.py, tests/tests_pytorch/models/test_hparams.py, tests/tests_pytorch/utilities/test_imports.py, src/lightning/fabric/accelerators/cuda.py, src/lightning/fabric/fabric.py, src/lightning/fabric/utilities/__init__.py, src/lightning/fabric/utilities/imports.py, src/lightning/fabric/utilities/rank_zero.py, src/lightning/fabric/utilities/throughput.py.

🟢 pytorch_lightning: Benchmarks
Check ID Status
lightning.Benchmarks success

These checks are required after the changes to src/lightning/fabric/accelerators/cuda.py, src/lightning/fabric/fabric.py, src/lightning/fabric/utilities/__init__.py, src/lightning/fabric/utilities/imports.py, src/lightning/fabric/utilities/rank_zero.py, src/lightning/fabric/utilities/throughput.py, src/lightning/pytorch/callbacks/__init__.py, src/lightning/pytorch/callbacks/throughput_monitor.py, src/lightning/pytorch/core/saving.py, src/lightning/pytorch/trainer/connectors/checkpoint_connector.py, src/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py, src/lightning/pytorch/trainer/connectors/signal_connector.py, src/lightning/pytorch/utilities/__init__.py, src/lightning/pytorch/utilities/imports.py.

🟢 fabric: Docs
Check ID Status
docs-make (fabric, doctest) success
docs-make (fabric, html) success

These checks are required after the changes to src/lightning/fabric/accelerators/cuda.py, src/lightning/fabric/fabric.py, src/lightning/fabric/utilities/__init__.py, src/lightning/fabric/utilities/imports.py, src/lightning/fabric/utilities/rank_zero.py, src/lightning/fabric/utilities/throughput.py, docs/source-fabric/api/utilities.rst, docs/source-fabric/conf.py.

🟢 pytorch_lightning: Docs
Check ID Status
docs-make (pytorch, doctest) success
docs-make (pytorch, html) success

These checks are required after the changes to src/lightning/pytorch/callbacks/__init__.py, src/lightning/pytorch/callbacks/throughput_monitor.py, src/lightning/pytorch/core/saving.py, src/lightning/pytorch/trainer/connectors/checkpoint_connector.py, src/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py, src/lightning/pytorch/trainer/connectors/signal_connector.py, src/lightning/pytorch/utilities/__init__.py, src/lightning/pytorch/utilities/imports.py, docs/source-pytorch/api_references.rst, docs/source-pytorch/conf.py.

🟢 lightning_fabric: CPU workflow
Check ID Status
fabric-cpu (macOS-11, lightning, 3.8, 1.12, oldest) success
fabric-cpu (macOS-11, lightning, 3.9, 1.12) success
fabric-cpu (macOS-11, lightning, 3.10, 1.13) success
fabric-cpu (macOS-11, lightning, 3.10, 2.0) success
fabric-cpu (macOS-11, lightning, 3.11, 2.1) success
fabric-cpu (ubuntu-20.04, lightning, 3.8, 1.12, oldest) success
fabric-cpu (ubuntu-20.04, lightning, 3.9, 1.12) success
fabric-cpu (ubuntu-20.04, lightning, 3.10, 1.13) success
fabric-cpu (ubuntu-20.04, lightning, 3.10, 2.0) success
fabric-cpu (ubuntu-20.04, lightning, 3.11, 2.1) success
fabric-cpu (windows-2022, lightning, 3.8, 1.12, oldest) success
fabric-cpu (windows-2022, lightning, 3.9, 1.12) success
fabric-cpu (windows-2022, lightning, 3.10, 1.13) success
fabric-cpu (windows-2022, lightning, 3.10, 2.0) success
fabric-cpu (windows-2022, lightning, 3.11, 2.1) success
fabric-cpu (macOS-11, fabric, 3.8, 1.13) success
fabric-cpu (ubuntu-20.04, fabric, 3.8, 1.13) success
fabric-cpu (windows-2022, fabric, 3.8, 1.13) success
fabric-cpu (macOS-12, fabric, 3.11, 2.0) success
fabric-cpu (macOS-12, fabric, 3.11, 2.1) success
fabric-cpu (ubuntu-22.04, fabric, 3.11, 2.0) success
fabric-cpu (ubuntu-22.04, fabric, 3.11, 2.1) success
fabric-cpu (windows-2022, fabric, 3.11, 2.0) success
fabric-cpu (windows-2022, fabric, 3.11, 2.1) success

These checks are required after the changes to src/lightning/fabric/accelerators/cuda.py, src/lightning/fabric/fabric.py, src/lightning/fabric/utilities/__init__.py, src/lightning/fabric/utilities/imports.py, src/lightning/fabric/utilities/rank_zero.py, src/lightning/fabric/utilities/throughput.py, tests/tests_fabric/utilities/test_throughput.py.

🟢 lightning_fabric: Azure GPU
Check ID Status
lightning-fabric (GPUs) (testing Fabric | latest) success
lightning-fabric (GPUs) (testing Lightning | latest) success

These checks are required after the changes to src/lightning/fabric/accelerators/cuda.py, src/lightning/fabric/fabric.py, src/lightning/fabric/utilities/__init__.py, src/lightning/fabric/utilities/imports.py, src/lightning/fabric/utilities/rank_zero.py, src/lightning/fabric/utilities/throughput.py, tests/tests_fabric/utilities/test_throughput.py.

🟢 mypy
Check ID Status
mypy success

These checks are required after the changes to src/lightning/fabric/accelerators/cuda.py, src/lightning/fabric/fabric.py, src/lightning/fabric/utilities/__init__.py, src/lightning/fabric/utilities/imports.py, src/lightning/fabric/utilities/rank_zero.py, src/lightning/fabric/utilities/throughput.py, src/lightning/pytorch/callbacks/__init__.py, src/lightning/pytorch/callbacks/throughput_monitor.py, src/lightning/pytorch/core/saving.py, src/lightning/pytorch/trainer/connectors/checkpoint_connector.py, src/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py, src/lightning/pytorch/trainer/connectors/signal_connector.py, src/lightning/pytorch/utilities/__init__.py, src/lightning/pytorch/utilities/imports.py.

🟢 install
Check ID Status
install-pkg (ubuntu-22.04, app, 3.8) success
install-pkg (ubuntu-22.04, app, 3.11) success
install-pkg (ubuntu-22.04, fabric, 3.8) success
install-pkg (ubuntu-22.04, fabric, 3.11) success
install-pkg (ubuntu-22.04, pytorch, 3.8) success
install-pkg (ubuntu-22.04, pytorch, 3.11) success
install-pkg (ubuntu-22.04, lightning, 3.8) success
install-pkg (ubuntu-22.04, lightning, 3.11) success
install-pkg (ubuntu-22.04, notset, 3.8) success
install-pkg (ubuntu-22.04, notset, 3.11) success
install-pkg (macOS-12, app, 3.8) success
install-pkg (macOS-12, app, 3.11) success
install-pkg (macOS-12, fabric, 3.8) success
install-pkg (macOS-12, fabric, 3.11) success
install-pkg (macOS-12, pytorch, 3.8) success
install-pkg (macOS-12, pytorch, 3.11) success
install-pkg (macOS-12, lightning, 3.8) success
install-pkg (macOS-12, lightning, 3.11) success
install-pkg (macOS-12, notset, 3.8) success
install-pkg (macOS-12, notset, 3.11) success
install-pkg (windows-2022, app, 3.8) success
install-pkg (windows-2022, app, 3.11) success
install-pkg (windows-2022, fabric, 3.8) success
install-pkg (windows-2022, fabric, 3.11) success
install-pkg (windows-2022, pytorch, 3.8) success
install-pkg (windows-2022, pytorch, 3.11) success
install-pkg (windows-2022, lightning, 3.8) success
install-pkg (windows-2022, lightning, 3.11) success
install-pkg (windows-2022, notset, 3.8) success
install-pkg (windows-2022, notset, 3.11) success

These checks are required after the changes to src/lightning/fabric/accelerators/cuda.py, src/lightning/fabric/fabric.py, src/lightning/fabric/utilities/__init__.py, src/lightning/fabric/utilities/imports.py, src/lightning/fabric/utilities/rank_zero.py, src/lightning/fabric/utilities/throughput.py, src/lightning/pytorch/callbacks/__init__.py, src/lightning/pytorch/callbacks/throughput_monitor.py, src/lightning/pytorch/core/saving.py, src/lightning/pytorch/trainer/connectors/checkpoint_connector.py, src/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py, src/lightning/pytorch/trainer/connectors/signal_connector.py, src/lightning/pytorch/utilities/__init__.py, src/lightning/pytorch/utilities/imports.py.


Thank you for your contribution! 💜

Note
This comment is automatically generated and is updated every 180 seconds for up to 60 minutes. If you have any other questions, contact carmocca for help.

@carmocca marked this pull request as ready for review on October 25, 2023 19:46
@awaelchli (Contributor) left a comment


The feature is certainly very valuable, thank you for adding it. My only gripe is with the way things get logged in Fabric: it is not flexible enough currently, and I would probably struggle to use the monitor together with, e.g., the wandb logger if the stepping is coded into the monitor itself.
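
For context, the kind of decoupling being asked for could look roughly like this, assuming compute() returns a plain dict of metrics as in the examples above (a sketch, not the API settled on in this PR):

# Compute the metrics explicitly and forward them to whichever logger was passed
# to Fabric (e.g. wandb), controlling the logging step yourself.
metrics = throughput.compute()
fabric.log_dict(metrics, step=i)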

Two review threads on src/lightning/fabric/utilities/throughput_monitor.py (outdated, resolved)
@carmocca changed the title from "Add ThroughputMonitor and measure_flops" to "Add throughput utilities to Fabric and the Trainer" on Oct 26, 2023
@mergify bot added the ready label on Oct 27, 2023