DeepSpeed still changes metric states from fp32 to fp16 #1561

Closed
FarzanT opened this issue Feb 27, 2023 · 7 comments · Fixed by #1583

@FarzanT
Contributor

FarzanT commented Feb 27, 2023

🐛 Bug

Following #484, PR #493 introduced set_dtype() to prevent .half() calls from changing the precision of metric states. However, at least for PearsonCorrCoef, DeepSpeed still modifies the dtype somehow. At initialization the metric states have the default dtype, torch.float32, but this changes as soon as DeepSpeed is initialized (see the code sample below).

I tried this with MeanAbsoluteError as well, and interestingly, out of sum_abs_error and total metric states, only sum_abs_error changes from torch.float32 to torch.float16. The total metric state remains as torch.int64. So DeepSpeed is probably only converting floats and not ints.
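
The float-only conversion is consistent with how `torch.nn.Module.half()` treats the tensors it reaches: floating-point tensors are cast, integer tensors are left alone. A minimal sketch, using plain buffers as a stand-in for metric states (the `FakeMAE` class is hypothetical and is not torchmetrics code):

```python
import torch


class FakeMAE(torch.nn.Module):
    """Hypothetical stand-in for a metric with one float and one int state."""

    def __init__(self):
        super().__init__()
        self.register_buffer("sum_abs_error", torch.tensor(0.0))  # float32
        self.register_buffer("total", torch.tensor(0))  # int64


# .half() only casts floating-point tensors; integer tensors keep their dtype.
module = FakeMAE().half()
dtypes = {name: buf.dtype for name, buf in module.named_buffers()}
print(dtypes)
```

This only illustrates the PyTorch side; whether DeepSpeed reaches the metric states through the same path is an assumption here.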

This is especially problematic for epoch-level metrics such as PearsonCorrCoef, since the values they accumulate can easily exceed 65,504, the largest finite value representable in fp16.
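
To make the fp16 limits concrete, here is a small standard-library sketch that rounds Python floats to half precision with `struct`'s `"e"` format. The helper names (`to_fp16`, `fp16_running_count`) are made up for illustration. Note one difference from tensors: `struct` raises `OverflowError` for out-of-range values, whereas an fp16 tensor would hold `inf`.

```python
import struct


def to_fp16(value: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("<e", struct.pack("<e", value))[0]


def fp16_running_count(steps: int) -> float:
    """Increment a counter by 1 per step, rounding to fp16 after each addition."""
    count = 0.0
    for _ in range(steps):
        count = to_fp16(count + 1.0)
    return count


# Above 2048, consecutive fp16 values are 2 apart, so adding 1 is rounded away.
stalled = fp16_running_count(5000)
print(stalled)

# Values beyond the fp16 range cannot be represented at all.
try:
    to_fp16(70000.0)
    overflowed = False
except OverflowError:
    overflowed = True
print(overflowed)
```

So an accumulator kept in fp16 can silently stop growing long before it overflows, which matters for states such as sample counts and running sums.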

To Reproduce

Refer to the minimal code sample.

Code sample
import os
import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer, LightningDataModule
from torchmetrics import PearsonCorrCoef, MeanAbsoluteError


class RandomDataset(Dataset):
    def __init__(self, size, num_samples):
        self.len = num_samples
        self.data = torch.randn(num_samples, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class PlDataModule(LightningDataModule):
    def train_dataloader(self):
        return DataLoader(RandomDataset(32, 64), batch_size=2)


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 32)
        self.metric = PearsonCorrCoef()
        self.mae = MeanAbsoluteError()
        print("Before DeepSpeed initialization")
        print("self.metric.mean_x", self.metric.mean_x)
        print("self.metric.mean_x.dtype", self.metric.mean_x.dtype)
        print("self.metric.mean_y", self.metric.mean_y)
        print("self.metric.dtype", self.metric.mean_y.dtype)
        print("self.metric.var_x", self.metric.var_x)
        print("self.metric.var_x.dtype", self.metric.var_x.dtype)
        print("self.metric.var_y", self.metric.var_y)
        print("self.metric.var_y.dtype", self.metric.var_y.dtype)
        print("self.metric.corr_xy", self.metric.corr_xy)
        print("self.metric.corr_xy.dtype", self.metric.corr_xy.dtype)
        print("self.metric.n_total", self.metric.n_total)

        print("self.mae.sum_abs_error.dtype", self.mae.sum_abs_error.dtype)
        print("self.mae.total.dtype", self.mae.total.dtype)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        pred = self.forward(batch)
        loss = pred.sum()  # reuse the forward pass computed above
        self.metric.update(torch.flatten(pred), torch.flatten(batch))

        print("After DeepSpeed initialization")
        print("self.metric.mean_x", self.metric.mean_x)
        print("self.metric.mean_x.dtype", self.metric.mean_x.dtype)
        print("self.metric.mean_y", self.metric.mean_y)
        print("self.metric.dtype", self.metric.mean_y.dtype)
        print("self.metric.var_x", self.metric.var_x)
        print("self.metric.var_x.dtype", self.metric.var_x.dtype)
        print("self.metric.var_y", self.metric.var_y)
        print("self.metric.var_y.dtype", self.metric.var_y.dtype)
        print("self.metric.corr_xy", self.metric.corr_xy)
        print("self.metric.corr_xy.dtype", self.metric.corr_xy.dtype)
        print("self.metric.n_total", self.metric.n_total)

        print("self.mae.sum_abs_error.dtype", self.mae.sum_abs_error.dtype)
        print("self.mae.total.dtype", self.mae.total.dtype)

        return {"loss": loss}

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


def run():
    model = BoringModel()
    trainer = Trainer(
        default_root_dir="/scratch/l/lstein/ftaj/scTransformers/scTransPert/checkpoints/temp/",
        # default_root_dir=os.getcwd(),
        limit_train_batches=10,
        num_sanity_val_steps=0,
        max_epochs=1,
        strategy="deepspeed_stage_1",
        accelerator="gpu",
        precision=16
    )
    trainer.fit(model, datamodule=PlDataModule())


run()
Output on my setup
(autopath) [ftaj@mist021 autopath]$ python src/deepspeed_issue.py
Before DeepSpeed initialization
self.metric.mean_x tensor([0.])
self.metric.mean_x.dtype torch.float32
self.metric.mean_y tensor([0.])
self.metric.dtype torch.float32
self.metric.var_x tensor([0.])
self.metric.var_x.dtype torch.float32
self.metric.var_y tensor([0.])
self.metric.var_y.dtype torch.float32
self.metric.corr_xy tensor([0.])
self.metric.corr_xy.dtype torch.float32
self.metric.n_total tensor([0.])
self.mae.sum_abs_error.dtype torch.float32
self.mae.total.dtype torch.int64
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
initializing deepspeed distributed: GLOBAL_RANK: 0, MEMBER: 1/1
Enabling DeepSpeed FP16.
/home/l/lstein/ftaj/.conda/envs/autopath/lib/python3.8/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:613: UserWarning: Checkpoint directory /scratch/l/lstein/ftaj/scTransformers/scTransPert/checkpoints/temp/lightning_logs/version_357072/checkpoints exists and is not empty.
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
[2023-02-26 21:48:53,595] [WARNING] [engine.py:1223:_do_optimizer_sanity_check] **** You are using ZeRO with an untested optimizer, proceed with caution *****
Rank: 0 partition count [1] and sizes[(1056, False)]

  | Name   | Type              | Params
---------------------------------------------
0 | layer  | Linear            | 1.1 K
1 | metric | PearsonCorrCoef   | 0
2 | mae    | MeanAbsoluteError | 0
---------------------------------------------
1.1 K     Trainable params
0         Non-trainable params
1.1 K     Total params
0.002     Total estimated model params size (MB)
SLURM auto-requeueing enabled. Setting signal handlers.
/home/l/lstein/ftaj/.conda/envs/autopath/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 128 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
/home/l/lstein/ftaj/.conda/envs/autopath/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py:1609: PossibleUserWarning: The number of training batches (10) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
  rank_zero_warn(
Epoch 0:   0%|                                                                                                                                                                                                                                           | 0/10 [00:00<?, ?it/s]
After DeepSpeed initialization
self.metric.mean_x tensor([-0.0167], device='cuda:0', dtype=torch.float16)
self.metric.mean_x.dtype torch.float16
self.metric.mean_y tensor([-0.0883], device='cuda:0', dtype=torch.float16)
self.metric.dtype torch.float16
self.metric.var_x tensor([18.6094], device='cuda:0', dtype=torch.float16)
self.metric.var_x.dtype torch.float16
self.metric.var_y tensor([71.9375], device='cuda:0', dtype=torch.float16)
self.metric.var_y.dtype torch.float16
self.metric.corr_xy tensor([4.5664], device='cuda:0', dtype=torch.float16)
self.metric.corr_xy.dtype torch.float16
self.metric.n_total tensor([64.], device='cuda:0', dtype=torch.float16)
self.mae.sum_abs_error.dtype torch.float16
self.mae.total.dtype torch.int64
Epoch 0:  10%|████████████████████                                                                                                                                                                                     | 1/10 [00:00<00:04,  2.20it/s, loss=-1.07, v_num=357072]
After DeepSpeed initialization
self.metric.mean_x tensor([0.0207], device='cuda:0', dtype=torch.float16)
self.metric.mean_x.dtype torch.float16
self.metric.mean_y tensor([-0.0510], device='cuda:0', dtype=torch.float16)
self.metric.dtype torch.float16
self.metric.var_x tensor([34.7812], device='cuda:0', dtype=torch.float16)
self.metric.var_x.dtype torch.float16
self.metric.var_y tensor([135.1250], device='cuda:0', dtype=torch.float16)
self.metric.var_y.dtype torch.float16
self.metric.corr_xy tensor([-1.7344], device='cuda:0', dtype=torch.float16)
self.metric.corr_xy.dtype torch.float16
self.metric.n_total tensor([128.], device='cuda:0', dtype=torch.float16)
self.mae.sum_abs_error.dtype torch.float16
self.mae.total.dtype torch.int64
Epoch 0:  20%|████████████████████████████████████████▍                                                                                                                                                                 | 2/10 [00:00<00:01,  4.31it/s, loss=1.33, v_num=357072]
After DeepSpeed initialization
self.metric.mean_x tensor([0.0435], device='cuda:0', dtype=torch.float16)
self.metric.mean_x.dtype torch.float16
self.metric.mean_y tensor([-0.0620], device='cuda:0', dtype=torch.float16)
self.metric.dtype torch.float16
self.metric.var_x tensor([55.9375], device='cuda:0', dtype=torch.float16)
self.metric.var_x.dtype torch.float16
self.metric.var_y tensor([202.5000], device='cuda:0', dtype=torch.float16)
self.metric.var_y.dtype torch.float16
self.metric.corr_xy tensor([-3.5781], device='cuda:0', dtype=torch.float16)
self.metric.corr_xy.dtype torch.float16
self.metric.n_total tensor([192.], device='cuda:0', dtype=torch.float16)
self.mae.sum_abs_error.dtype torch.float16
self.mae.total.dtype torch.int64
Epoch 0:  30%|████████████████████████████████████████████████████████████▌                                                                                                                                             | 3/10 [00:00<00:01,  6.35it/s, loss=2.78, v_num=357072]
After DeepSpeed initialization
self.metric.mean_x tensor([0.0287], device='cuda:0', dtype=torch.float16)
self.metric.mean_x.dtype torch.float16
self.metric.mean_y tensor([-0.0602], device='cuda:0', dtype=torch.float16)
self.metric.dtype torch.float16
self.metric.var_x tensor([73.5625], device='cuda:0', dtype=torch.float16)
self.metric.var_x.dtype torch.float16
self.metric.var_y tensor([258.5000], device='cuda:0', dtype=torch.float16)
self.metric.var_y.dtype torch.float16
self.metric.corr_xy tensor([-11.3281], device='cuda:0', dtype=torch.float16)
self.metric.corr_xy.dtype torch.float16
self.metric.n_total tensor([256.], device='cuda:0', dtype=torch.float16)
self.mae.sum_abs_error.dtype torch.float16
self.mae.total.dtype torch.int64
Epoch 0:  40%|████████████████████████████████████████████████████████████████████████████████▊                                                                                                                         | 4/10 [00:00<00:00,  8.29it/s, loss=1.83, v_num=357072]
After DeepSpeed initialization
self.metric.mean_x tensor([-0.0767], device='cuda:0', dtype=torch.float16)
self.metric.mean_x.dtype torch.float16
self.metric.mean_y tensor([-0.0827], device='cuda:0', dtype=torch.float16)
self.metric.dtype torch.float16
self.metric.var_x tensor([107.2500], device='cuda:0', dtype=torch.float16)
self.metric.var_x.dtype torch.float16
self.metric.var_y tensor([315.2500], device='cuda:0', dtype=torch.float16)
self.metric.var_y.dtype torch.float16
self.metric.corr_xy tensor([-8.4688], device='cuda:0', dtype=torch.float16)
self.metric.corr_xy.dtype torch.float16
self.metric.n_total tensor([320.], device='cuda:0', dtype=torch.float16)
self.mae.sum_abs_error.dtype torch.float16
self.mae.total.dtype torch.int64
Epoch 0:  50%|████████████████████████████████████████████████████████████████████████████████████████████████████▌                                                                                                    | 5/10 [00:00<00:00, 10.15it/s, loss=-4.91, v_num=357072]
After DeepSpeed initialization
self.metric.mean_x tensor([-0.0933], device='cuda:0', dtype=torch.float16)
self.metric.mean_x.dtype torch.float16
self.metric.mean_y tensor([-0.0465], device='cuda:0', dtype=torch.float16)
self.metric.dtype torch.float16
self.metric.var_x tensor([132.8750], device='cuda:0', dtype=torch.float16)
self.metric.var_x.dtype torch.float16
self.metric.var_y tensor([394.5000], device='cuda:0', dtype=torch.float16)
self.metric.var_y.dtype torch.float16
self.metric.corr_xy tensor([-4.7578], device='cuda:0', dtype=torch.float16)
self.metric.corr_xy.dtype torch.float16
self.metric.n_total tensor([384.], device='cuda:0', dtype=torch.float16)
self.mae.sum_abs_error.dtype torch.float16
self.mae.total.dtype torch.int64
Epoch 0:  60%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌                                                                                | 6/10 [00:00<00:00, 11.97it/s, loss=-5.97, v_num=357072]
After DeepSpeed initialization
self.metric.mean_x tensor([-0.3110], device='cuda:0', dtype=torch.float16)
self.metric.mean_x.dtype torch.float16
self.metric.mean_y tensor([-0.0288], device='cuda:0', dtype=torch.float16)
self.metric.dtype torch.float16
self.metric.var_x tensor([371.5000], device='cuda:0', dtype=torch.float16)
self.metric.var_x.dtype torch.float16
self.metric.var_y tensor([456.2500], device='cuda:0', dtype=torch.float16)
self.metric.var_y.dtype torch.float16
self.metric.corr_xy tensor([-15.1953], device='cuda:0', dtype=torch.float16)
self.metric.corr_xy.dtype torch.float16
self.metric.n_total tensor([448.], device='cuda:0', dtype=torch.float16)
self.mae.sum_abs_error.dtype torch.float16
self.mae.total.dtype torch.int64
Epoch 0:  70%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋                                                            | 7/10 [00:00<00:00, 13.70it/s, loss=-19.9, v_num=357072]
After DeepSpeed initialization
self.metric.mean_x tensor([-0.1224], device='cuda:0', dtype=torch.float16)
self.metric.mean_x.dtype torch.float16
self.metric.mean_y tensor([-0.0084], device='cuda:0', dtype=torch.float16)
self.metric.dtype torch.float16
self.metric.var_x tensor([673.], device='cuda:0', dtype=torch.float16)
self.metric.var_x.dtype torch.float16
self.metric.var_y tensor([510.5000], device='cuda:0', dtype=torch.float16)
self.metric.var_y.dtype torch.float16
self.metric.corr_xy tensor([8.4297], device='cuda:0', dtype=torch.float16)
self.metric.corr_xy.dtype torch.float16
self.metric.n_total tensor([512.], device='cuda:0', dtype=torch.float16)
self.mae.sum_abs_error.dtype torch.float16
self.mae.total.dtype torch.int64
Epoch 0:  80%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊                                        | 8/10 [00:00<00:00, 15.37it/s, loss=-7.84, v_num=357072]
After DeepSpeed initialization
self.metric.mean_x tensor([-0.1087], device='cuda:0', dtype=torch.float16)
self.metric.mean_x.dtype torch.float16
self.metric.mean_y tensor([-0.0055], device='cuda:0', dtype=torch.float16)
self.metric.dtype torch.float16
self.metric.var_x tensor([739.], device='cuda:0', dtype=torch.float16)
self.metric.var_x.dtype torch.float16
self.metric.var_y tensor([568.], device='cuda:0', dtype=torch.float16)
self.metric.var_y.dtype torch.float16
self.metric.corr_xy tensor([1.3281], device='cuda:0', dtype=torch.float16)
self.metric.corr_xy.dtype torch.float16
self.metric.n_total tensor([576.], device='cuda:0', dtype=torch.float16)
self.mae.sum_abs_error.dtype torch.float16
self.mae.total.dtype torch.int64
Epoch 0:  90%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉                    | 9/10 [00:00<00:00, 16.98it/s, loss=-6.96, v_num=357072]
After DeepSpeed initialization
self.metric.mean_x tensor([-0.0939], device='cuda:0', dtype=torch.float16)
self.metric.mean_x.dtype torch.float16
self.metric.mean_y tensor([0.0134], device='cuda:0', dtype=torch.float16)
self.metric.dtype torch.float16
self.metric.var_x tensor([862.5000], device='cuda:0', dtype=torch.float16)
self.metric.var_x.dtype torch.float16
self.metric.var_y tensor([642.], device='cuda:0', dtype=torch.float16)
self.metric.var_y.dtype torch.float16
self.metric.corr_xy tensor([-19.6406], device='cuda:0', dtype=torch.float16)
self.metric.corr_xy.dtype torch.float16
self.metric.n_total tensor([640.], device='cuda:0', dtype=torch.float16)
self.mae.sum_abs_error.dtype torch.float16
self.mae.total.dtype torch.int64
Epoch 0: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 18.49it/s, loss=-6.01, v_num=357072]/home/l/lstein/ftaj/.conda/envs/autopath/lib/python3.8/site-packages/torch/nn/modules/module.py:1428: UserWarning: Positional args are being deprecated, use kwargs instead. Refer to https://pytorch.org/docs/master/generated/torch.nn.Module.html#torch.nn.Module.state_dict for details.
  warnings.warn(
`Trainer.fit` stopped: `max_epochs=1` reached.
Epoch 0: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 17.59it/s, loss=-6.01, v_num=357072]

Expected behavior

DeepSpeed should not affect the dtype of metric states; that is, they should remain torch.float32 even when DeepSpeed is used with precision=16.

Environment

collect_env results
* CUDA:
        - GPU:
                - Tesla V100-SXM2-32GB
        - available:         True
        - version:           11.4
* Lightning:
        - lightning-cloud:   0.5.26
        - lightning-utilities: 0.6.0.post0
        - pytorch-lightning: 1.9.1
        - reformer-pytorch:  1.4.4
        - torch:             1.13.0

        - av:                10.0.0
        - axial-positional-embedding: 0.2.1
        - backcall:          0.2.0
        - beautifulsoup4:    4.11.1
        - bleach:            4.1.0
        - blessed:           1.20.0
        - bottleneck:        1.3.5
        - brotlipy:          0.7.0
        - cached-property:   1.5.2
        - captum:            0.5.0
        - certifi:           2022.12.7
        - cffi:              1.15.1
        - charset-normalizer: 2.0.4
        - click:             8.0.4
        - cmake:             3.25.0
        - croniter:          1.3.8
        - cryptography:      38.0.1
        - cycler:            0.11.0
        - dateutils:         0.6.12
        - debugpy:           1.5.1
        - decorator:         5.1.1
        - deepdiff:          6.2.3
        - deepspeed:         0.8.0
        - defusedxml:        0.7.1
        - dnspython:         2.3.0
        - docker-pycreds:    0.4.0
        - einops:            0.4.1
        - email-validator:   1.3.1
        - entrypoints:       0.4
        - executing:         0.8.3
        - fastapi:           0.88.0
        - fastjsonschema:    2.16.2
        - filelock:          3.8.2
        - fire:              0.4.0
        - frozenlist:        1.3.3
        - fsspec:            2022.11.0
        - future:            0.18.2
        - gitdb:             4.0.9
        - gitpython:         3.1.29
        - h11:               0.14.0
        - h5py:              3.6.0
        - hjson:             3.1.0
        - httpcore:          0.16.3
        - httptools:         0.5.0
        - httpx:             0.23.3
        - idna:              3.4
        - importlib-metadata: 5.0.0
        - importlib-resources: 5.2.0
        - inquirer:          3.1.2
        - ipykernel:         6.15.2
        - ipython:           8.6.0
        - ipython-genutils:  0.2.0
        - itsdangerous:      2.1.2
        - jedi:              0.18.1
        - jinja2:            3.1.2
        - joblib:            1.2.0
        - jsonschema:        4.16.0
        - jupyter-client:    7.3.4
        - jupyter-core:      4.11.2

        - readchar:          4.0.3
        - reformer-pytorch:  1.4.4
        - requests:          2.28.1
        - rfc3986:           1.5.0
        - rich:              13.3.1
        - routing-transformer: 1.6.1
        - scanpy:            1.9.1
        - scikit-learn:      1.1.2
        - scipy:             1.8.1
        - seaborn:           0.12.1
        - send2trash:        1.8.0
        - sentencepiece:     0.1.96
        - sentry-sdk:        1.10.1
        - session-info:      1.0.0
        - setproctitle:      1.3.2
        - setuptools:        65.5.0
        - shortuuid:         1.0.10
        - six:               1.16.0
        - smmap:             5.0.0
        - sniffio:           1.2.0
        - soupsieve:         2.3.2.post1
        - stack-data:        0.2.0
        - starlette:         0.22.0
        - starsessions:      1.3.0
        - statsmodels:       0.11.1
        - stdlib-list:       0.8.0
        - tabulate:          0.8.10
        - tensorboardx:      2.5.1
        - termcolor:         2.1.1
        - terminado:         0.13.1
        - threadpoolctl:     3.1.0
        - tinycss2:          1.2.1
        - torch:             1.13.0
        - torchmetrics:      0.11.2
        - torchtext:         0.14.0a0+f8d65dd
        - torchvision:       0.14.0
        - tornado:           6.1
        - tqdm:              4.64.1
        - traitlets:         5.9.0
        - typing-extensions: 4.3.0
        - typing-inspect:    0.8.0
        - ujson:             5.7.0
        - umap-learn:        0.5.2
        - urllib3:           1.26.12
        - uvicorn:           0.20.0
        - uvloop:            0.17.0
        - wandb:             0.13.5
        - watchfiles:        0.18.1
        - wcwidth:           0.2.5
        - webencodings:      0.5.1
        - websocket-client:  0.58.0
        - websockets:        10.4
        - wheel:             0.37.1
        - x-transformers:    1.5.1
        - xformers:          0.0.13
        - yarl:              1.8.2
        - zipp:              3.10.0
* System:
        - OS:                Linux
        - architecture:
                - 64bit
                - ELF
        - processor:         ppc64le
        - python:            3.8.13
        - version:           #1 SMP Thu Nov 17 08:57:02 EST 2022

Additional context

@FarzanT FarzanT added bug / fix Something isn't working help wanted Extra attention is needed labels Feb 27, 2023
@github-actions

Hi! Thanks for your contribution! Great first issue!

@SkafteNicki
Member

@justusschock do you have experience with the internals of DeepSpeed? Is this something we can get around or is this a limitation?

@justusschock
Member

@SkafteNicki I don't have much experience with DeepSpeed internals, but I'll have a look at why and where this happens.

@justusschock justusschock self-assigned this Feb 27, 2023
@FarzanT
Contributor Author

FarzanT commented Feb 27, 2023

@justusschock Hi! One thing I noticed is that the PearsonCorrCoef shows up in the model summary, even though I'm not using this metric for backpropagation, only during validation. Since PearsonCorrCoef is differentiable, perhaps it gets added to the computation graph and DeepSpeed finds another way to convert it? Just a guess.

@justusschock
Member

Hey @FarzanT, PearsonCorrCoef shows up in the model summary because it is a subclass of Metric (which in turn is a subclass of torch.nn.Module), and the model summary lists all nn.Modules attached to the LightningModule. This should not, however, result in dtype changes for metrics, as this is explicitly guarded.

@FarzanT
Contributor Author

FarzanT commented Feb 28, 2023

@justusschock
I realized that calling .half() on a LightningModule converts the metric states. This can be seen with the following code example:

import torch

from pytorch_lightning import LightningModule
from torchmetrics import PearsonCorrCoef, MeanAbsoluteError


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 32)
        self.pearson = PearsonCorrCoef()
        self.mae = MeanAbsoluteError()

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        pred = self.forward(batch)
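        loss = pred.sum()  # reuse the forward pass computed above
        # The metric attribute is named `pearson` in __init__, not `metric`.
        self.pearson.update(torch.flatten(pred), torch.flatten(batch))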

        return {"loss": loss}

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)

model = BoringModel()
print(model.mae.sum_abs_error.dtype)
print(model.pearson.mean_x.dtype)
model = model.half()
print(model.mae.sum_abs_error.dtype)
print(model.pearson.mean_x.dtype)
model = model.float()
print(model.mae.sum_abs_error.dtype)
print(model.pearson.mean_x.dtype)

Output:

torch.float32
torch.float32
torch.float16
torch.float16
torch.float32
torch.float32

Any pointers on how to protect the metrics states from this?

@FarzanT
Contributor Author

FarzanT commented Mar 2, 2023

@justusschock @SkafteNicki I think I figured out the source of the problem:
https://github.com/Lightning-AI/metrics/blob/bc722a7e487628a8be83bd8e4eaddc25c18698b2/src/torchmetrics/metric.py#L697-L730
The _apply() function in metric.py was written specifically to apply .to(), .cuda() and similar calls to the metric states. I guess the same goes for .half(). So, when you call .half() on a LightningModule, .half() is applied directly to the metric states, circumventing the 'guards' introduced in #493.
Should we put a check in the _apply() function to ignore functions that change dtype, like half() and float() etc.?

May I open a pull request?
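
One possible shape for such a check, as a rough sketch under stated assumptions rather than the torchmetrics implementation: `GuardedMetric` and its `total` state are hypothetical names, and the guard simply lets `_apply` reach the state (so device moves still work) and then casts the result back to the state's original dtype.

```python
import torch


class GuardedMetric(torch.nn.Module):
    """Hypothetical stand-in for a metric; not torchmetrics code."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(2, 2)  # ordinary parameters, still converted
        self.total = torch.zeros(1)  # plain tensor attribute acting as a state

    def _apply(self, fn, *args, **kwargs):
        # Let nn.Module handle parameters, buffers and child modules first.
        module = super()._apply(fn, *args, **kwargs)
        original_dtype = module.total.dtype
        # Apply fn so device moves still reach the state, then undo any dtype change.
        module.total = fn(module.total).to(dtype=original_dtype)
        return module


model = GuardedMetric().half()
state_dtype = model.total.dtype
weight_dtype = model.layer.weight.dtype
print(state_dtype, weight_dtype)
```

A guard this blunt would also undo an intentional cast such as `.to(torch.float64)`, so a real fix would need an explicit opt-in path for dtype changes, for example through the `set_dtype()` method mentioned above.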
