
torchmetrics 0.4+ broke high-precision metric states #484

Closed
leezu opened this issue Aug 27, 2021 · 6 comments · Fixed by #493
Labels: bug / fix (Something isn't working), help wanted (Extra attention is needed)

leezu commented Aug 27, 2021

🐛 Bug

Metric states with a custom dtype are no longer handled correctly in torchmetrics 0.4 and 0.5.
Consider the following metric, which works fine in 0.3:

class SimpleMetric(torchmetrics.Metric):
    def __init__(self, dist_sync_on_step=False):
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        self.add_state(
            "state", default=torch.tensor([0], dtype=torch.float64), dist_reduce_fx="sum"
        )

    def update(self, x: torch.Tensor):
        self.state += x.double()

    def compute(self):
        print(self.state)
        return self.state

Code sample to reproduce

import os

import torch
import torchmetrics
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer, LightningDataModule


class SimpleMetric(torchmetrics.Metric):
    def __init__(self, dist_sync_on_step=False):
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        self.add_state(
            "state", default=torch.tensor([0], dtype=torch.float64), dist_reduce_fx="sum"
        )

    def update(self, x: torch.Tensor):
        self.state += x.double()

    def compute(self):
        print(self.state)
        return self.state



class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class PlDataModule(LightningDataModule):
    def train_dataloader(self):
        return DataLoader(RandomDataset(32, 64), batch_size=2)


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)
        self.metric = SimpleMetric()

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.metric(loss)
        return {"loss": loss}

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


def run():

    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_train_batches=10,
        num_sanity_val_steps=0,
        max_epochs=1,
        weights_summary=None,
        accelerator="deepspeed",
        precision=16,
        gpus=1,
        progress_bar_refresh_rate=0,
    )
    trainer.fit(model, datamodule=PlDataModule())


if __name__ == "__main__":
    run()

Actual behavior

This is the behavior with pytorch-lightning 1.3 or higher + torchmetrics 0.4 or higher.
Note that the state is automatically cast to fp16 by DeepSpeed (see Lightning-AI/pytorch-lightning#7476).

% python3 test.py
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
initializing ddp: GLOBAL_RANK: 0, MEMBER: 1/1
----------------------------------------------------------------------------------------------------
distributed_backend=nccl
All DDP processes registered. Starting ddp with 1 processes
----------------------------------------------------------------------------------------------------

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
Enabling DeepSpeed FP16.
You have not specified an optimizer or scheduler within the DeepSpeed config.Using `configure_optimizers` to define optimizer and scheduler.
[2021-08-27 22:30:13,586] [WARNING] [engine.py:726:_configure_optimizer] **** You are using ZeRO with an untested optimizer, proceed with caution *****
Using /home/ubuntu/.cache/torch_extensions/py38_cu114 as PyTorch extensions root...
Emitting ninja build file /home/ubuntu/.cache/torch_extensions/py38_cu114/utils/build.ninja...
Building extension module utils...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
ninja: no work to do.
Loading extension module utils...
Time to load utils op: 0.08644556999206543 seconds
Using /home/ubuntu/.cache/torch_extensions/py38_cu114 as PyTorch extensions root...
No modifications detected for re-loaded extension module utils, skipping build step...
Loading extension module utils...
Time to load utils op: 0.0003902912139892578 seconds
/home/ubuntu/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/data_loading.py:102: UserWarning: The dataloader, train dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 96 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
tensor([-0.2761], device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
tensor([-0.8501], device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
tensor([-1.5273], device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
tensor([-0.8506], device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
tensor([-0.7441], device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
tensor([-0.3738], device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
tensor([-5.4336], device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
tensor([-2.0137], device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
tensor([-2.3652], device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
tensor([-4.8438], device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)

Expected behavior

This is the behavior with pytorch-lightning 1.3 + torchmetrics 0.3.
Note that the state remains fp64.

% python3 test.py
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
initializing ddp: GLOBAL_RANK: 0, MEMBER: 1/1
----------------------------------------------------------------------------------------------------
distributed_backend=nccl
All DDP processes registered. Starting ddp with 1 processes
----------------------------------------------------------------------------------------------------

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
Enabling DeepSpeed FP16.
You have not specified an optimizer or scheduler within the DeepSpeed config.Using `configure_optimizers` to define optimizer and scheduler.
[2021-08-27 22:31:11,341] [WARNING] [engine.py:726:_configure_optimizer] **** You are using ZeRO with an untested optimizer, proceed with caution *****
Using /home/ubuntu/.cache/torch_extensions/py38_cu114 as PyTorch extensions root...
Emitting ninja build file /home/ubuntu/.cache/torch_extensions/py38_cu114/utils/build.ninja...
Building extension module utils...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
ninja: no work to do.
Loading extension module utils...
Time to load utils op: 0.08263587951660156 seconds
Using /home/ubuntu/.cache/torch_extensions/py38_cu114 as PyTorch extensions root...
No modifications detected for re-loaded extension module utils, skipping build step...
Loading extension module utils...
Time to load utils op: 0.0003662109375 seconds
/home/ubuntu/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/data_loading.py:102: UserWarning: The dataloader, train dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 96 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
tensor([-0.8662], device='cuda:0', dtype=torch.float64, grad_fn=<AddBackward0>)
tensor([-0.8799], device='cuda:0', dtype=torch.float64, grad_fn=<AddBackward0>)
tensor([1.5469], device='cuda:0', dtype=torch.float64, grad_fn=<AddBackward0>)
tensor([-0.6304], device='cuda:0', dtype=torch.float64, grad_fn=<AddBackward0>)
tensor([1.0391], device='cuda:0', dtype=torch.float64, grad_fn=<AddBackward0>)
tensor([0.0771], device='cuda:0', dtype=torch.float64, grad_fn=<AddBackward0>)
tensor([3.7480], device='cuda:0', dtype=torch.float64, grad_fn=<AddBackward0>)
tensor([-0.2329], device='cuda:0', dtype=torch.float64, grad_fn=<AddBackward0>)
tensor([-7.7930], device='cuda:0', dtype=torch.float64, grad_fn=<AddBackward0>)
tensor([-5.8711], device='cuda:0', dtype=torch.float64, grad_fn=<AddBackward0>)

Additional context

Reverting

-        # Also apply fn to metric states
-        for key in this._defaults.keys():
+        # Also apply fn to metric states and defaults
+        for key, value in this._defaults.items():
+            if isinstance(value, Tensor):
+                this._defaults[key] = fn(value)
+            elif isinstance(value, Sequence):
+                this._defaults[key] = [fn(v) for v in value]

from cda5dbd fixes the issue.
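
The mechanism can be illustrated with a simplified, self-contained sketch (my own, not the torchmetrics implementation): the metric keeps an untouched copy of each default and restores its states from that copy on reset, so once the conversion fn is applied to the defaults as well, the original dtype can no longer be recovered.

import torch


class MiniMetric:
    def __init__(self):
        self._defaults = {"state": torch.tensor([0.0], dtype=torch.float64)}
        self.state = self._defaults["state"].clone()

    def apply_fn(self, fn, also_to_defaults):
        # stand-in for nn.Module._apply reaching the metric via model.half()
        self.state = fn(self.state)
        if also_to_defaults:  # the behaviour introduced in cda5dbd
            self._defaults = {k: fn(v) for k, v in self._defaults.items()}

    def reset(self):
        # states are restored from the stored defaults
        self.state = self._defaults["state"].clone()


m = MiniMetric()
m.apply_fn(lambda t: t.half(), also_to_defaults=False)
m.reset()
print(m.state.dtype)  # torch.float64 -- the untouched default restores the dtype

m = MiniMetric()
m.apply_fn(lambda t: t.half(), also_to_defaults=True)
m.reset()
print(m.state.dtype)  # torch.float16 -- the float64 default has been lost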

cc @SkafteNicki

leezu added the bug / fix and help wanted labels Aug 27, 2021
@github-actions

Hi! Thanks for your contribution, great first issue!


SkafteNicki commented Aug 28, 2021

Hi @leezu,
So this also came up in a recent PR #462.
This is not so much a bug fix as a feature request. IMO the user should have a way of controlling which dtype they want to use with their metric, the obvious choice being that the conversion happens whenever .half(), .double(), etc. is called. That said, I understand that there is a need to disable this behaviour when training with mixed precision.

This is definitely something we need to include in the API redesign going on in PR #479.

leezu mentioned this issue Aug 28, 2021

leezu commented Aug 28, 2021

Thank you @SkafteNicki. I'd agree with you on "feature request" if this had never worked. But it's a bug in that cda5dbd enabled overwriting of the user's defaults when a fn is applied recursively to a Module :) This happens, for example, in mixed precision training; see https://github.com/microsoft/DeepSpeed/blob/5b393f1555143968fbac78ecbf3c00f65baf7d78/deepspeed/runtime/engine.py#L563-L570

> IMO the user should have a way of controlling which dtype they want to use with their metric, the obvious choice being that the conversion happens whenever .half(), .double(), etc. is called.

As the Metric is a Module and an attribute of the actual model, half(), double(), etc. can be invoked on the parent Module for mixed precision training, even though the metric should not be transitioned to half().
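
A minimal sketch of that cascade, using a plain float64 buffer as a stand-in for a metric state (torchmetrics handles its states through a Metric._apply override rather than buffers, but the effect is the same):

import torch
import torch.nn as nn


class MetricLike(nn.Module):
    def __init__(self):
        super().__init__()
        # high-precision accumulator, analogous to a metric state
        self.register_buffer("state", torch.zeros(1, dtype=torch.float64))


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(32, 2)
        self.metric = MetricLike()


model = Model()
model.half()  # roughly what DeepSpeed does when enabling fp16 training
print(model.metric.state.dtype)  # torch.float16 -- the float64 state was converted too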

I'm looking forward to the redesign in #479. Please ping me if I can help test any changes. Thank you!

leezu added a commit to leezu/metrics that referenced this issue Aug 29, 2021

Borda commented Aug 30, 2021

@leezu mind checking today's RC v0.5.1?


leezu commented Aug 30, 2021

The issue persists on v0.5.1rc0 @Borda. The reason is that DeepSpeed calls .half() (https://github.com/microsoft/DeepSpeed/blob/5b393f1555143968fbac78ecbf3c00f65baf7d78/deepspeed/runtime/engine.py#L563-L570), and as of cda5dbd this is also applied to the user defaults. I.e. without cda5dbd, when the defaults are re-applied after the half() call, the original dtype is restored.

SkafteNicki mentioned this issue Sep 2, 2021
@SkafteNicki

@leezu could you please try with PR #493? Calling half, double, float should no longer have an effect on the metric states; they should stay whatever you initialize them as. Only the new metric set_dtype method should affect them :]
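
A quick sanity check against that branch could look like the following sketch, assuming set_dtype is exposed on Metric as described above:

import torch
import torchmetrics


class SimpleMetric(torchmetrics.Metric):
    def __init__(self):
        super().__init__()
        self.add_state(
            "state", default=torch.tensor([0], dtype=torch.float64), dist_reduce_fx="sum"
        )

    def update(self, x: torch.Tensor):
        self.state += x.double()

    def compute(self):
        return self.state


metric = SimpleMetric()

metric.half()  # should leave the state untouched after PR #493
print(metric.state.dtype)  # expected: torch.float64

metric.set_dtype(torch.float32)  # explicit opt-in to a dtype change
print(metric.state.dtype)  # expected: torch.float32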
