[BUG] (DataLoader) sanity check fails due to Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) #20456

Open
MathiasBaumgartinger opened this issue Nov 27, 2024 · 6 comments
Labels: 3rd party, ver: 2.4.x

MathiasBaumgartinger commented Nov 27, 2024

Bug description

Hi there! I have created my first LightningDataModule, more specifically a NonGeoDataModule, which inherits from it (see torchgeo-fork). Interestingly, when I try to run this module I get RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same or input should be a MKLDNN tensor and weight is a dense tensor. Even more interesting is that if I override transfer_batch_to_device like this:

def transfer_batch_to_device(self, batch: Any, device: torch.device, dataloader_idx: int) -> Any:
    batch = super().transfer_batch_to_device(batch, device, dataloader_idx)
    print("----------------------------------------")
    for k in batch.keys():
        print(k, batch[k][0].get_device())
    print("----------------------------------------")

    return batch

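I get the following output (get_device() returns 0 for cuda:0 and -1 for a CPU tensor, so both tensors are on the GPU at this point):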

image 0
mask 0

It happens during the validation step (lightning/pytorch/strategies/strategy.py, line 411), i.e. during the sanity check that runs before training starts.
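Because get_device() only reports a device index, a helper that prints the full device string for every entry makes it easier to compare the batch at different points in the pipeline. The sketch below is hypothetical (describe_batch_devices is not part of Lightning or torchgeo) and assumes a dict batch; calling it from transfer_batch_to_device, on_after_batch_transfer and the start of validation_step would show where the tensors end up back on the CPU.

```python
from typing import Any, Dict

import torch


def describe_batch_devices(batch: Dict[str, Any]) -> Dict[str, str]:
    """Map each batch key to the device of its value.

    Tensors report their full device string (e.g. "cpu" or "cuda:0");
    non-tensor entries such as string IDs are reported as "n/a".
    """
    report: Dict[str, str] = {}
    for key, value in batch.items():
        if isinstance(value, torch.Tensor):
            report[key] = str(value.device)
        else:
            report[key] = "n/a"
    return report


# Example with a CPU batch; inside a hook you would print(describe_batch_devices(batch))
batch = {
    "image": torch.zeros(2, 5, 8, 8),
    "mask": torch.zeros(2, 8, 8, dtype=torch.long),
    "image_id": ["a", "b"],
}
print(describe_batch_devices(batch))
```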

What version are you seeing the problem on?

v2.4

How to reproduce the bug

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger
from torchgeo.trainers import SemanticSegmentationTask

# Assumed to be defined elsewhere in train_simple.py: FL (the NonGeoDataModule from the
# torchgeo fork), TuneReportCheckpointCallback (Ray Tune), default_data_dir,
# default_root_dir and accelerator.


def train(
    config: dict,
    data_dir: str = default_data_dir,
    root_dir: str = default_root_dir,
    min_epochs: int = 1,
    max_epochs: int = 25,
) -> None:

    tune_metrics = {"loss": "ptl/val_loss", "acc": "ptl/val_accuracy"}

    module = FL(
        num_workers=config["num_workers"],
        batch_size=config["batch_size"],
        patch_size=config["patch_size"],
        val_split_pct=0.25,
        use_toy=True,
        # augs=transforms,
        root=data_dir,
    )
    task = SemanticSegmentationTask(
        model="unet",
        backbone="resnet50",
        ignore_index=255,
        in_channels=5,  # (5+3) with appended indices
        num_classes=13,
        lr=config["lr"],
        patience=config["lr_patience"],
    )

    # Callbacks
    checkpoint_callback = ModelCheckpoint(monitor="val_loss", save_top_k=1, mode="min")
    lr_monitor = LearningRateMonitor(logging_interval="step")
    tune_callback = TuneReportCheckpointCallback(
        {"loss": "val_loss", "accuracy": "val_accuracy"}, on="validation_end"
    )
    logger = TensorBoardLogger(save_dir=root_dir, name="FLAIR2logs")

    trainer = Trainer(
        accelerator=accelerator,
        num_nodes=1,
        callbacks=[checkpoint_callback, lr_monitor, tune_callback],
        log_every_n_steps=1,
        logger=logger,
        # Use the function arguments instead of hard-coded values
        min_epochs=min_epochs,
        max_epochs=max_epochs,
        precision=32,
    )

    trainer.fit(model=task, datamodule=module)

Error messages and logs

Traceback (most recent call last):
  File "//Dev/forks/torchgeo/train_simple.py", line 158, in <module>
    main()
  File "//Dev/forks/torchgeo/train_simple.py", line 154, in main
    train(config)
  File "//Dev/forks/torchgeo/train_simple.py", line 151, in train
    trainer.fit(model=task, datamodule=module)
  File "//miniforge3/envs/torchgeo/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 538, in fit
    call._call_and_handle_interrupt(
  File "//miniforge3/envs/torchgeo/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py", line 47, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "//miniforge3/envs/torchgeo/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 574, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "//miniforge3/envs/torchgeo/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 981, in _run
    results = self._run_stage()
  File "//miniforge3/envs/torchgeo/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 1023, in _run_stage
    self._run_sanity_check()
  File "//miniforge3/envs/torchgeo/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 1052, in _run_sanity_check
    val_loop.run()
  File "//miniforge3/envs/torchgeo/lib/python3.10/site-packages/lightning/pytorch/loops/utilities.py", line 178, in _decorator
    return loop_run(self, *args, **kwargs)
  File "//miniforge3/envs/torchgeo/lib/python3.10/site-packages/lightning/pytorch/loops/evaluation_loop.py", line 135, in run
    self._evaluation_step(batch, batch_idx, dataloader_idx, dataloader_iter)
  File "//miniforge3/envs/torchgeo/lib/python3.10/site-packages/lightning/pytorch/loops/evaluation_loop.py", line 396, in _evaluation_step
    output = call._call_strategy_hook(trainer, hook_name, *step_args)
  File "//miniforge3/envs/torchgeo/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py", line 319, in _call_strategy_hook
    output = fn(*args, **kwargs)
  File "//miniforge3/envs/torchgeo/lib/python3.10/site-packages/lightning/pytorch/strategies/strategy.py", line 411, in validation_step
    return self.lightning_module.validation_step(*args, **kwargs)
  File "//Dev/forks/torchgeo/torchgeo/trainers/segmentation.py", line 251, in validation_step
    y_hat = self(x)
  File "//miniforge3/envs/torchgeo/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "//miniforge3/envs/torchgeo/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "//Dev/forks/torchgeo/torchgeo/trainers/base.py", line 81, in forward
    return self.model(*args, **kwargs)
  File "//miniforge3/envs/torchgeo/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "//miniforge3/envs/torchgeo/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "//miniforge3/envs/torchgeo/lib/python3.10/site-packages/segmentation_models_pytorch/base/model.py", line 38, in forward
    features = self.encoder(x)
  File "//miniforge3/envs/torchgeo/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "//miniforge3/envs/torchgeo/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "//miniforge3/envs/torchgeo/lib/python3.10/site-packages/segmentation_models_pytorch/encoders/resnet.py", line 63, in forward
    x = stages[i](x)
  File "//miniforge3/envs/torchgeo/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "//miniforge3/envs/torchgeo/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "//miniforge3/envs/torchgeo/lib/python3.10/site-packages/torch/nn/modules/container.py", line 219, in forward
    input = module(input)
  File "//miniforge3/envs/torchgeo/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "//miniforge3/envs/torchgeo/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "//miniforge3/envs/torchgeo/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 458, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "//miniforge3/envs/torchgeo/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 454, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same or input should be a MKLDNN tensor and weight is a dense tensor

Environment

Current environment
-----------------------------------------------------------
Python Version: 3.10.4
PyTorch Version: 2.4.1
Cuda is  available version: 12.4
Torch built with CUDA: True
cuDNN Version: 90100
cuDNN Enabled: True
cuDNN available: True
Device: cuda
Accelerator: gpu

lightning                 2.4.0             
lightning-utilities       0.11.9             
pytorch-lightning         2.4.0 

## conda env
name: torchgeo
channels:
  - pytorch
  - nvidia
  - conda-forge
  - defaults
dependencies:
  - python=3.10
  - pytorch-cuda=12.4
  - pytorch=2.4
  - torchgeo=0.6.0
  - tensorboard=2.17
-----------------------------------------------------------

More info

No response

MathiasBaumgartinger added the bug and needs triage labels Nov 27, 2024

MathiasBaumgartinger (Author) commented Nov 28, 2024

After some debugging, I found that indeed, the batches were not on the GPU during the different steps. When I add .to(self.device) to the batch['image'] and batch['mask'] accesses (see: https://github.com/microsoft/torchgeo/blob/main/torchgeo/trainers/segmentation.py), the pipeline executes without errors.

AFAIK, this should not be necessary: a pl.Trainer used with a pl.LightningDataModule and a pl.LightningModule should guarantee that the batch and the model are on the same device via the transfer_batch_to_device hook. And that hook is clearly being called, since I do see its printed output.

EDIT: Tagging @adamjstewart as this might be related to torchgeo

robmarkcole commented Nov 28, 2024

Hit this too. Not sure how significant it is that the batch is a dict.
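A dict batch by itself should not be the problem: the default collate function turns a list of dict samples into a dict of stacked tensors, and Lightning's default transfer_batch_to_device recurses into dicts, lists and tuples. A quick CPU-only check of the collate step (the shapes loosely mirror the DummyDataset posted later in this thread and are otherwise arbitrary):

```python
import torch
from torch.utils.data import default_collate

# Two samples shaped like a 4-band image plus a single-channel class mask
samples = [
    {"image": torch.randn(4, 16, 16), "mask": torch.randint(0, 2, (16, 16))},
    {"image": torch.randn(4, 16, 16), "mask": torch.randint(0, 2, (16, 16))},
]

batch = default_collate(samples)

# The result is still a dict, with each value stacked along a new batch dimension
print(type(batch).__name__, batch["image"].shape, batch["mask"].shape)
```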

lantiga (Collaborator) commented Dec 5, 2024

Hey @MathiasBaumgartinger, it definitely isn't needed in standard PyTorch Lightning code.

I'm not sure how data loading is handled on the torchgeo side in detail, though. @adamjstewart, maybe you have a hint here?

lantiga added the 3rd party label and removed the bug and needs triage labels Dec 5, 2024
adamjstewart (Contributor) commented

The only hacky thing TorchGeo does is to override transfer_batch_to_device like so:

    def transfer_batch_to_device(self, batch, device, dataloader_idx):
        # Non-Tensor values cannot be moved to a device
        del batch['crs']
        del batch['bounds']

        batch = super().transfer_batch_to_device(batch, device, dataloader_idx)
        return batch

But that's only for GeoDataModule, not NonGeoDataModule. All definitions can be seen here: https://github.com/microsoft/torchgeo/blob/main/torchgeo/datamodules/geo.py

@MathiasBaumgartinger @robmarkcole, maybe you can provide more details on how to reproduce this? I haven't personally hit this issue yet.
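For reference, the default transfer_batch_to_device that the override above reaches via super() delegates to Lightning's move_data_to_device, which recursively moves tensors inside collections and leaves other values alone. The simplified stand-in below (hypothetical name move_batch; handles dicts, lists, tuples and namedtuples but not dataclasses) makes that behavior easy to test outside a Trainer:

```python
from typing import Any

import torch


def move_batch(batch: Any, device: torch.device) -> Any:
    """Recursively move tensors in dicts/lists/tuples to ``device``.

    Non-tensor leaves (strings, CRS objects, floats, ...) are returned unchanged.
    """
    if isinstance(batch, torch.Tensor):
        return batch.to(device)
    if isinstance(batch, dict):
        return {key: move_batch(value, device) for key, value in batch.items()}
    if isinstance(batch, tuple) and hasattr(batch, "_fields"):
        # namedtuples need their fields passed positionally
        return type(batch)(*(move_batch(item, device) for item in batch))
    if isinstance(batch, (list, tuple)):
        return type(batch)(move_batch(item, device) for item in batch)
    return batch
```

Swapping "meta" in for "cuda" lets the recursion be checked on a machine without a GPU.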

robmarkcole commented

I encountered this when using a litdata streaming dataset, serialised via:

def fetch_and_serialize_fn(index):
    """Serialize as bytes."""
    sample = dataset[index]  # Fetch the sample
    data = {
        "image_id": sample["image_id"],
        "mask": sample["mask_path"],
        "image": sample["image_path"],
    }
    return data

Which is then handled by a streaming dataset:

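from typing import Callable, Dict, List, Optional

import numpy as np
import torch
from litdata import StreamingDataset
from rasterio.io import MemoryFile
from torch import Tensor

# move_smallest_dim_to_first is a helper defined elsewhere in this codebase


class BaseStreamingDataset(StreamingDataset):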
    """
    Base class for streaming datasets.

    Args:
        input_dir: Local directory or S3 location of the dataset
        transforms: A transform that takes in an image and returns a transformed version.
        band_indices: List of band indices to read from the dataset.
    """

    def __init__(
        self,
        *args,
        transforms: Optional[Callable] = None,
        band_indices: Optional[List[int]] = None,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.transforms = transforms
        self.band_indices = band_indices

    def _load_image(self, image_data):
        """Load image data from bytes or NumPy arrays and process it."""
        if isinstance(image_data, bytes):
            # Load image data from bytes using rasterio
            with MemoryFile(image_data) as memfile:
                with memfile.open() as dataset:
                    if self.band_indices is None:
                        image = dataset.read()
                    else:
                        image = dataset.read(
                            indexes=[i + 1 for i in self.band_indices]
                        )  # Rasterio indexes start at 1
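        elif isinstance(image_data, np.ndarray):
            # Require channels first
            image_data = move_smallest_dim_to_first(image_data)
            # Select bands only when requested; otherwise keep every channel
            # (previously `image` was left undefined when band_indices was None)
            if self.band_indices is None:
                image = image_data
            else:
                image = image_data[self.band_indices]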
        else:
            raise ValueError(f"Unsupported image data type: {type(image_data)}")

        image = torch.from_numpy(image.astype(np.float32)).float()
        image = self._impute_nans(image)
        return image

    def _load_mask(self, mask_data):
        """Load mask data from bytes or PyTorch tensor."""
        if isinstance(mask_data, bytes):
            with MemoryFile(mask_data) as memfile:
                with memfile.open() as dataset:
                    mask = dataset.read()
        elif isinstance(mask_data, np.ndarray):
            # 2D mask expected
            mask = mask_data
        else:
            raise ValueError(f"Unsupported mask data type: {type(mask_data)}")
        mask = torch.from_numpy(mask).long()
        return mask

    def _impute_nans(self, image):
        if torch.isnan(image).any():
            for channel in range(image.shape[0]):
                mean_val = image[channel].nanmean()
                image[channel] = torch.where(
                    torch.isnan(image[channel]), mean_val, image[channel]
                )
        return image

    def plot(self, sample: dict[str, Tensor]):
        raise NotImplementedError("This method should be implemented by subclasses")


class SegmentationStreamingDataset(BaseStreamingDataset):
    """
    Segmentation dataset with streaming support.

    Args:
        input_dir: Local directory or S3 location of the dataset
        transforms: A transform that takes in an image and returns a transformed version.
        bias_mask: A value to apply to the mask, applied before remapping.
        remap_values: Dictionary to remap values in the segmentation mask.
        binarise_mask: If True, the mask is binarized to 0 and 1.
    """

    def __init__(
        self,
        *args,
        bias_mask: Optional[int] = None,
        remap_values: Optional[Dict[int, int]] = None,
        binarise_mask: Optional[bool] = False,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.bias_mask = bias_mask
        self.remap_values = remap_values
        self.binarise_mask = binarise_mask

    def __getitem__(self, index) -> dict:
        data = super().__getitem__(index)
        sample = {}

        image = self._load_image(data["image"])
        mask = self._load_mask(data["mask"])

        if self.bias_mask:
            mask = mask + self.bias_mask

        if self.remap_values:
            for original_value, new_value in self.remap_values.items():
                mask[mask == original_value] = new_value

        if self.binarise_mask:
            mask[mask > 0] = 1

        if mask.dim() == 3:
            mask = mask.squeeze(0)

        if torch.isnan(mask).any():
            raise ValueError(f"NaNs found in mask for index {index}")

        sample["image"] = image
        sample["mask"] = mask
        if self.transforms is not None:
            sample = self.transforms(sample)
        if "image_id" in data.keys():
            sample["image_id"] = data["image_id"]
        return sample

Via a streaming datamodule:

"""
Use StreamingNonGeoDataModule when working with non-geospatial datasets within 
the TorchGeo framework, leveraging TorchGeo's NonGeoDataModule functionalities.
"""
from typing import Any, Dict, List, Optional, Type

from litdata import StreamingDataLoader
from torch.utils.data import Dataset, default_collate
from torchgeo.datamodules import NonGeoDataModule
from torchgeo.datasets import NonGeoDataset
from upath import UPath
from upath.implementations.cloud import S3Path


class StreamingNonGeoDataModule(NonGeoDataModule):
    """
    The dataloaders are of type StreamingDataLoader compatible with litdata.
    This datamodule requires that root contains train, val and test folders.
    """

    def __init__(
        self,
        dataset_class: Type[NonGeoDataset],
        root: str = "data",
        batch_size: int = 1,
        num_workers: int = 0,
        band_indices: Optional[List[int]] = None,
        bias_mask: Optional[int] = None,
        remap_values: Optional[Dict[int, int]] = None,
        binarise_mask: Optional[bool] = False,
        prefetch_factor: int = 2,
        persistent_workers: bool = False,
        multiprocessing_context: Optional[Any] = None,
    ) -> None:
        """
        Args:
            dataset_class: Class used to instantiate a new dataset.
            root: Path to the input directory or s3 location which contains train, val and test folders.
            batch_size: Size of each mini-batch.
            num_workers: Number of workers for parallel data loading.
            band_indices: List of band indices to read from the dataset.
            bias_mask: A value to apply to the mask, applied before remapping.
            remap_values: Dictionary to remap values in the mask.
            binarise_mask: If True, the mask is binarized to 0 and 1.
            prefetch_factor: Number of batches to prefetch.
            persistent_workers: If True, workers are not shutdown after each epoch.
            multiprocessing_context: Context for multiprocessing
        """
        super().__init__(dataset_class, batch_size, num_workers)
        self.prefetch_factor = prefetch_factor
        self.multiprocessing_context = multiprocessing_context
        self.persistent_workers = persistent_workers

        if not root.endswith("/"):
            root = root + "/"
        if isinstance(UPath(root), S3Path):
            print("root is S3Path")
            self.verify_s3_folder(root)

        self.root = root
        self.band_indices = band_indices
        self.bias_mask = bias_mask
        self.remap_values = remap_values
        self.binarise_mask = binarise_mask
        self.collate_fn = default_collate

    @staticmethod
    def verify_s3_folder(s3_url: str) -> None:
        """Verify the existence of an S3 'folder' (prefix) using upath."""
        try:
            path = UPath(s3_url)
            # Check if the directory is not empty (i.e., it has contents)
            contents = list(path.iterdir())
            if not contents:
                raise FileNotFoundError(f"S3 'folder' not found or is empty: {s3_url}")

        except Exception as e:
            raise RuntimeError(f"Failed to verify S3 'folder': {s3_url}. Error: {e}") from e

    def setup(self, stage: str) -> None:
        """Set up datasets.

        Args:
            stage: Either 'fit', 'validate', 'test', or 'predict'.
        """
        if stage == "fit":
            self.train_dataset = self.create_dataset(subset="train")
            self.val_dataset = self.create_dataset(subset="val")
        elif stage == "validate":
            self.val_dataset = self.create_dataset(subset="val")
        elif stage == "test":
            self.test_dataset = self.create_dataset(subset="test")

    def create_dataset(self, subset: str) -> Dataset:
        """Create dataset instance with appropriate parameters.

        Args:
            subset: Either 'train', 'val', or 'test'.

        Returns:
            Dataset instance.
        """
        dataset_kwargs: Dict[str, Any] = {"input_dir": f"{self.root}{subset}"}
        if self.band_indices is not None:
            dataset_kwargs["band_indices"] = self.band_indices
        if self.bias_mask is not None:
            dataset_kwargs["bias_mask"] = self.bias_mask
        if self.remap_values is not None:
            dataset_kwargs["remap_values"] = self.remap_values
        if self.binarise_mask:
            dataset_kwargs["binarise_mask"] = self.binarise_mask

        return self.dataset_class(**dataset_kwargs)

    def _dataloader_factory(self, split: str) -> StreamingDataLoader:
        """Implement one or more PyTorch DataLoaders.

        Args:
            split: Either 'train', 'val', 'test', or 'predict'.
        """
        dataset = self._valid_attribute(f"{split}_dataset", "dataset")
        batch_size = self._valid_attribute(f"{split}_batch_size", "batch_size")
        return StreamingDataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=split == "train",
            num_workers=self.num_workers,
            collate_fn=self.collate_fn,
            prefetch_factor=self.prefetch_factor,
            persistent_workers=self.persistent_workers,
            multiprocessing_context=self.multiprocessing_context,
        )

The lightning module has this logic:

    def unpack_batch(self, batch):
        if isinstance(batch, tuple):
            return batch  # Assume tuple is in the form (x, y)
        elif isinstance(batch, dict):
            # https://github.com/Lightning-AI/pytorch-lightning/issues/20456
            return batch["image"].to(self.device), batch["mask"].to(self.device)  # Unpack from dict
        elif isinstance(batch, list) and all(
            isinstance(item, torch.Tensor) for item in batch
        ):
            # Assuming the first tensor is images and the second is masks
            return batch[0], batch[1]
        else:
            raise TypeError("Unsupported batch format")
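Note that in unpack_batch above only the dict branch moves tensors to self.device; the tuple and list branches return them wherever they already are. A standalone sketch of the same workaround applied to every supported format (hypothetical free function that takes the target device as an argument instead of reading self.device):

```python
from typing import Any, Tuple

import torch
from torch import Tensor


def unpack_batch(batch: Any, device: torch.device) -> Tuple[Tensor, Tensor]:
    """Return (image, mask) from a tuple, list or dict batch, moved to ``device``."""
    if isinstance(batch, dict):
        x, y = batch["image"], batch["mask"]
    elif isinstance(batch, (tuple, list)) and len(batch) >= 2:
        # Assume the first two entries are images and masks
        x, y = batch[0], batch[1]
    else:
        raise TypeError(f"Unsupported batch format: {type(batch).__name__}")
    if not (isinstance(x, Tensor) and isinstance(y, Tensor)):
        raise TypeError("Expected image and mask to be tensors")
    # Every branch goes through the same device move
    return x.to(device), y.to(device)
```

This is still a workaround: it hides where the batch lost its device rather than fixing it.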

calebrob6 (Contributor) commented Dec 19, 2024

Hey all, another torchgeo user here. I hit this as well with Kornia 0.7.4.

Here's a minimal reproducible example:

import kornia.augmentation as K
import lightning.pytorch as pl
import torch
from lightning.pytorch import LightningDataModule
from torch.utils.data import DataLoader, Dataset

from torchgeo.trainers import SemanticSegmentationTask

class DummyDataset(Dataset):

    def __init__(self, size=256, length=32):
        super().__init__()
        self.size = size
        self.length = length

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        img = torch.randn((4, self.size, self.size)).float()
        mask = torch.randint(0, 2, (self.size, self.size)).long()
        return {
            "image": img,
            "mask": mask,
        }


class DummyDataModule(LightningDataModule):

    def __init__(self):
        super().__init__()

        self.train_aug = K.AugmentationSequential(
            K.RandomRotation(p=0.5, degrees=90),
            data_keys=None,
            keepdim=True
        )

    def setup(self, stage=None):
        self.ds = DummyDataset(size=256, length=32)

    def train_dataloader(self):
        return DataLoader(self.ds, batch_size=4, num_workers=2, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.ds, batch_size=4, num_workers=2, shuffle=False)

    def on_after_batch_transfer(self, batch, dataloader_idx):
        print("-" * 21)
        print(batch["image"].device)
        print(self.train_aug[0].device)
        print("-" * 21)
        if self.trainer:
            if self.trainer.training:
                batch = self.train_aug(batch)
                batch["mask"] = batch["mask"].squeeze()  # keepdim not working in 0.7.3?
        return batch


dm = DummyDataModule()

task = SemanticSegmentationTask(
    model="unet",
    backbone="resnet18",
    in_channels=4,
    num_classes=2
)

trainer = pl.Trainer(max_epochs=1, accelerator="gpu", devices=[2])
trainer.fit(task, dm)

If you run the above on 0.7.4 you get the same error that @MathiasBaumgartinger reported.

@adamjstewart, if you run this on 0.7.3, you can see that dm.train_aug is still on the CPU, so I don't think our augmentations are working as expected, period.
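Some context for this observation: LightningDataModule is not an nn.Module, so the Trainer moves the LightningModule's parameters and buffers to the accelerator but never touches modules stored as datamodule attributes, such as the Kornia augmentations here. The pure-PyTorch sketch below uses a hypothetical toy Normalize module and stand-in classes (no Kornia, no Lightning) to show the difference. Moving such a module explicitly, for example with self.train_aug.to(batch["image"].device) inside on_after_batch_transfer, is one possible workaround, not a confirmed fix for this issue.

```python
import torch
from torch import nn


class Normalize(nn.Module):
    """Toy per-channel normalization with mean/std stored as buffers."""

    def __init__(self, mean: torch.Tensor, std: torch.Tensor) -> None:
        super().__init__()
        self.register_buffer("mean", mean.view(-1, 1, 1))
        self.register_buffer("std", std.view(-1, 1, 1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return (x - self.mean) / self.std


class Task(nn.Module):
    """Stands in for a LightningModule: registered submodules move with .to()."""

    def __init__(self) -> None:
        super().__init__()
        self.norm = Normalize(torch.zeros(4), torch.ones(4))


class DataHolder:
    """Stands in for a LightningDataModule: a plain object, not an nn.Module."""

    def __init__(self) -> None:
        self.aug = Normalize(torch.zeros(4), torch.ones(4))


task, dm = Task(), DataHolder()
task.to("meta")  # "meta" plays the role of the GPU so this runs without CUDA
print(task.norm.mean.device, dm.aug.mean.device)  # prints: meta cpu
dm.aug.to("meta")  # explicit move, as in the suggested workaround
```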

5 participants