[BUG] (DataLoader) sanity check fails due to Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) #20456
After some debugging, I found that the batches were indeed not on the GPU during the different steps. When I add …

AFAIK, this should not be necessary, as a …

EDIT: Tagging @adamjstewart, as this might be related to torchgeo.
Hit this too. Not sure how significant it is that the batch is a dict.
Hey @MathiasBaumgartinger, it definitely isn't needed in standard PyTorch Lightning code. I'm not sure how data loading is managed on the torchgeo side in detail, though. @adamjstewart, maybe you have a hint here?
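For context: Lightning's default `transfer_batch_to_device` walks nested collections (dicts, lists, tuples) and moves every tensor it finds, which is why a dict batch normally needs no manual `.to(self.device)`. The stdlib-only sketch below illustrates that recursive idea; it is not Lightning's implementation, and `move_to_device` and `StubTensor` are hypothetical names. Because it duck-types on a callable `.to`, it would also accept real `torch.Tensor` objects.

```python
from collections.abc import Mapping
from typing import Any


def move_to_device(batch: Any, device: Any) -> Any:
    """Recursively move tensor-like leaves of ``batch`` to ``device``.

    Anything exposing a callable ``.to`` (e.g. ``torch.Tensor``) is moved;
    strings, ints and other metadata are returned unchanged.
    """
    to = getattr(batch, "to", None)
    if callable(to):
        return to(device)
    if isinstance(batch, Mapping):
        # Dict batches such as {"image": ..., "mask": ...} are walked key by key
        return {key: move_to_device(value, device) for key, value in batch.items()}
    if isinstance(batch, tuple) and hasattr(batch, "_fields"):
        # namedtuple: rebuild it with positional arguments
        return type(batch)(*(move_to_device(v, device) for v in batch))
    if isinstance(batch, (list, tuple)):
        return type(batch)(move_to_device(v, device) for v in batch)
    return batch


class StubTensor:
    """Hypothetical stand-in for torch.Tensor that only records its device."""

    def __init__(self, device: str = "cpu") -> None:
        self.device = device

    def to(self, device: str) -> "StubTensor":
        return StubTensor(device)


batch = {"image": StubTensor(), "mask": StubTensor(), "image_id": "tile_001"}
moved = move_to_device(batch, "cuda:0")
print(moved["image"].device, moved["mask"].device, moved["image_id"])
# cuda:0 cuda:0 tile_001
```

The original `batch` is left untouched; the moved copy is what a training or validation step would receive.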
The only hacky thing TorchGeo does is to override `transfer_batch_to_device`:

```python
def transfer_batch_to_device(self, batch, device, dataloader_idx):
    # Non-Tensor values cannot be moved to a device
    del batch['crs']
    del batch['bounds']

    batch = super().transfer_batch_to_device(batch, device, dataloader_idx)
    return batch
```

But that's only for … @MathiasBaumgartinger @robmarkcole, maybe you can provide more details on how to reproduce this; I haven't personally hit this issue yet.
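One detail of this override: `del batch['crs']` raises `KeyError` if a batch does not contain that key. A hypothetical defensive variant of the key-stripping step is sketched below in plain Python; `strip_non_tensor_keys` and `NON_TENSOR_KEYS` are illustrative names, not TorchGeo API. Inside the override, it would be called on the batch before `super().transfer_batch_to_device(...)`.

```python
from typing import Any, Dict, Iterable

# The two metadata keys the TorchGeo override deletes before the transfer.
NON_TENSOR_KEYS = ("crs", "bounds")


def strip_non_tensor_keys(
    batch: Dict[str, Any], keys: Iterable[str] = NON_TENSOR_KEYS
) -> Dict[str, Any]:
    """Return a shallow copy of ``batch`` without ``keys``.

    Unlike ``del batch[key]``, ``dict.pop(key, None)`` does not raise
    ``KeyError`` when a key is absent, e.g. for batches from a dataset
    that never produces ``crs`` or ``bounds``.
    """
    stripped = dict(batch)  # copy so the caller's batch is left untouched
    for key in keys:
        stripped.pop(key, None)
    return stripped


geo_batch = {"image": "<tensor>", "crs": "EPSG:4326", "bounds": (0, 0, 1, 1)}
plain_batch = {"image": "<tensor>", "mask": "<tensor>"}

print(strip_non_tensor_keys(geo_batch))    # {'image': '<tensor>'}
print(strip_non_tensor_keys(plain_batch))  # {'image': '<tensor>', 'mask': '<tensor>'}
```

Returning a copy rather than mutating in place is a design choice of this sketch; the original override mutates the batch directly.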
I encountered this when using a litdata streaming dataset, serialised via:

```python
def fetch_and_serialize_fn(index):
    """Serialize as bytes."""
    sample = dataset[index]  # Fetch the sample
    data = {
        "image_id": sample["image_id"],
        "mask": sample["mask_path"],
        "image": sample["image_path"],
    }
    return data
```

Which is then handled by a streaming dataset:

```python
from typing import Callable, Dict, List, Optional

import numpy as np
import torch
from litdata import StreamingDataset
from rasterio.io import MemoryFile
from torch import Tensor

# move_smallest_dim_to_first is a project-specific helper defined elsewhere.


class BaseStreamingDataset(StreamingDataset):
    """
    Base class for streaming datasets.

    Args:
        input_dir: Local directory or S3 location of the dataset
        transforms: A transform that takes in an image and returns a transformed version.
        band_indices: List of band indices to read from the dataset.
    """

    def __init__(
        self,
        *args,
        transforms: Optional[Callable] = None,
        band_indices: Optional[List[int]] = None,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.transforms = transforms
        self.band_indices = band_indices

    def _load_image(self, image_data):
        """Load image data from bytes or NumPy arrays and process it."""
        if isinstance(image_data, bytes):
            # Load image data from bytes using rasterio
            with MemoryFile(image_data) as memfile:
                with memfile.open() as dataset:
                    if self.band_indices is None:
                        image = dataset.read()
                    else:
                        image = dataset.read(
                            indexes=[i + 1 for i in self.band_indices]
                        )  # Rasterio indexes start at 1
        elif isinstance(image_data, np.ndarray):
            # Require channels first
            image_data = move_smallest_dim_to_first(image_data)
            if self.band_indices is not None:
                image = image_data[self.band_indices]
            else:
                image = image_data  # Keep all bands (otherwise `image` is unbound)
        else:
            raise ValueError(f"Unsupported image data type: {type(image_data)}")

        image = torch.from_numpy(image.astype(np.float32)).float()
        image = self._impute_nans(image)
        return image

    def _load_mask(self, mask_data):
        """Load mask data from bytes or a NumPy array."""
        if isinstance(mask_data, bytes):
            with MemoryFile(mask_data) as memfile:
                with memfile.open() as dataset:
                    mask = dataset.read()
        elif isinstance(mask_data, np.ndarray):
            # 2D mask expected
            mask = mask_data
        else:
            raise ValueError(f"Unsupported mask data type: {type(mask_data)}")

        mask = torch.from_numpy(mask).long()
        return mask

    def _impute_nans(self, image):
        # Replace NaNs in each channel with that channel's mean over non-NaN values
        if torch.isnan(image).any():
            for channel in range(image.shape[0]):
                mean_val = image[channel].nanmean()
                image[channel] = torch.where(
                    torch.isnan(image[channel]), mean_val, image[channel]
                )
        return image

    def plot(self, sample: dict[str, Tensor]):
        raise NotImplementedError("This method should be implemented by subclasses")


class SegmentationStreamingDataset(BaseStreamingDataset):
    """
    Segmentation dataset with streaming support.

    Args:
        input_dir: Local directory or S3 location of the dataset
        transforms: A transform that takes in an image and returns a transformed version.
        bias_mask: A value to apply to the mask, applied before remapping.
        remap_values: Dictionary to remap values in the segmentation mask.
        binarise_mask: If True, the mask is binarized to 0 and 1.
    """

    def __init__(
        self,
        *args,
        bias_mask: Optional[int] = None,
        remap_values: Optional[Dict[int, int]] = None,
        binarise_mask: Optional[bool] = False,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.bias_mask = bias_mask
        self.remap_values = remap_values
        self.binarise_mask = binarise_mask

    def __getitem__(self, index) -> dict:
        data = super().__getitem__(index)
        sample = {}

        image = self._load_image(data["image"])
        mask = self._load_mask(data["mask"])

        if self.bias_mask:
            mask = mask + self.bias_mask
        if self.remap_values:
            for original_value, new_value in self.remap_values.items():
                mask[mask == original_value] = new_value
        if self.binarise_mask:
            mask[mask > 0] = 1
        if mask.dim() == 3:
            mask = mask.squeeze(0)
        if torch.isnan(mask).any():
            raise ValueError(f"NaNs found in mask for index {index}")

        sample["image"] = image
        sample["mask"] = mask
        if self.transforms is not None:
            sample = self.transforms(sample)
        if "image_id" in data:
            sample["image_id"] = data["image_id"]
        return sample
```

Via a streaming datamodule:

```python
"""
Use StreamingNonGeoDataModule when working with non-geospatial datasets within
the TorchGeo framework, leveraging TorchGeo's NonGeoDataModule functionalities.
"""

import re
from typing import Any, Dict, List, Optional, Type

from litdata import StreamingDataLoader
from torch.utils.data import Dataset, default_collate
from torchgeo.datamodules import NonGeoDataModule
from torchgeo.datasets import NonGeoDataset
from upath import UPath
from upath.implementations.cloud import S3Path


class StreamingNonGeoDataModule(NonGeoDataModule):
    """
    The dataloaders are of type StreamingDataLoader compatible with litdata.
    This datamodule requires that root contains train, val and test folders.
    """

    def __init__(
        self,
        dataset_class: Type[NonGeoDataset],
        root: str = "data",
        batch_size: int = 1,
        num_workers: int = 0,
        band_indices: Optional[List[int]] = None,
        bias_mask: Optional[int] = None,
        remap_values: Optional[Dict[int, int]] = None,
        binarise_mask: Optional[bool] = False,
        prefetch_factor: int = 2,
        persistent_workers: bool = False,
        multiprocessing_context: Optional[Any] = None,
    ) -> None:
        """
        Args:
            dataset_class: Class used to instantiate a new dataset.
            root: Path to the input directory or s3 location which contains train, val and test folders.
            batch_size: Size of each mini-batch.
            num_workers: Number of workers for parallel data loading.
            band_indices: List of band indices to read from the dataset.
            bias_mask: A value to apply to the mask, applied before remapping.
            remap_values: Dictionary to remap values in the mask.
            binarise_mask: If True, the mask is binarized to 0 and 1.
            prefetch_factor: Number of batches to prefetch.
            persistent_workers: If True, workers are not shutdown after each epoch.
            multiprocessing_context: Context for multiprocessing
        """
        super().__init__(dataset_class, batch_size, num_workers)
        self.prefetch_factor = prefetch_factor
        self.multiprocessing_context = multiprocessing_context
        self.persistent_workers = persistent_workers

        if not root.endswith("/"):
            root = root + "/"
        if isinstance(UPath(root), S3Path):
            print("root is S3Path")
            self.verify_s3_folder(root)

        self.root = root
        self.band_indices = band_indices
        self.bias_mask = bias_mask
        self.remap_values = remap_values
        self.binarise_mask = binarise_mask
        self.collate_fn = default_collate

    @staticmethod
    def verify_s3_folder(s3_url: str) -> None:
        """Verify the existence of an S3 'folder' (prefix) using upath."""
        try:
            path = UPath(s3_url)
            # Check if the directory is not empty (i.e., it has contents)
            contents = list(path.iterdir())
            if not contents:
                raise FileNotFoundError(f"S3 'folder' not found or is empty: {s3_url}")
        except Exception as e:
            raise RuntimeError(f"Failed to verify S3 'folder': {s3_url}. Error: {e}") from e

    def setup(self, stage: str) -> None:
        """Set up datasets.

        Args:
            stage: Either 'fit', 'validate', 'test', or 'predict'.
        """
        if stage == "fit":
            self.train_dataset = self.create_dataset(subset="train")
            self.val_dataset = self.create_dataset(subset="val")
        elif stage == "validate":
            self.val_dataset = self.create_dataset(subset="val")
        elif stage == "test":
            self.test_dataset = self.create_dataset(subset="test")

    def create_dataset(self, subset: str) -> Dataset:
        """Create dataset instance with appropriate parameters.

        Args:
            subset: Either 'train', 'val', or 'test'.

        Returns:
            Dataset instance.
        """
        dataset_kwargs: Dict[str, Any] = {"input_dir": f"{self.root}{subset}"}
        if self.band_indices is not None:
            dataset_kwargs["band_indices"] = self.band_indices
        if self.bias_mask is not None:
            dataset_kwargs["bias_mask"] = self.bias_mask
        if self.remap_values is not None:
            dataset_kwargs["remap_values"] = self.remap_values
        if self.binarise_mask:
            dataset_kwargs["binarise_mask"] = self.binarise_mask
        return self.dataset_class(**dataset_kwargs)

    def _dataloader_factory(self, split: str) -> StreamingDataLoader:
        """Implement one or more PyTorch DataLoaders.

        Args:
            split: Either 'train', 'val', 'test', or 'predict'.
        """
        dataset = self._valid_attribute(f"{split}_dataset", "dataset")
        batch_size = self._valid_attribute(f"{split}_batch_size", "batch_size")
        return StreamingDataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=split == "train",
            num_workers=self.num_workers,
            collate_fn=self.collate_fn,
            prefetch_factor=self.prefetch_factor,
            persistent_workers=self.persistent_workers,
            multiprocessing_context=self.multiprocessing_context,
        )
```

The lightning module has this logic:

```python
import torch


def unpack_batch(self, batch):  # method of the LightningModule
    if isinstance(batch, tuple):
        return batch  # Assume tuple is in the form (x, y)
    elif isinstance(batch, dict):
        # Workaround for https://github.com/Lightning-AI/pytorch-lightning/issues/20456:
        # the tensors arrive on the CPU, so move them to the module's device manually
        return batch["image"].to(self.device), batch["mask"].to(self.device)  # Unpack from dict
    elif isinstance(batch, list) and all(
        isinstance(item, torch.Tensor) for item in batch
    ):
        # Assuming the first tensor is images and the second is masks
        return batch[0], batch[1]
    else:
        raise TypeError("Unsupported batch format")
```
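A side note on `SegmentationStreamingDataset.__getitem__` above: the `remap_values` loop assigns into the same mask it compares against, so a value written by one pair can be matched again by a later pair (with `{1: 2, 2: 3}`, original 1s end up as 3). This only matters when the keys and values of `remap_values` overlap. The plain-Python sketch below uses nested lists instead of tensors, and `remap_sequential` and `remap_simultaneous` are illustrative names; it contrasts the two behaviours. With tensors, the equivalent fix is to compute every `mask == original_value` comparison on the original mask before assigning any new values.

```python
from typing import Dict, List

Mask = List[List[int]]


def remap_sequential(mask: Mask, remap: Dict[int, int]) -> Mask:
    """Mirror the loop in __getitem__: each pair sees the already-remapped mask."""
    out = [row[:] for row in mask]
    for original, new in remap.items():
        out = [[new if v == original else v for v in row] for row in out]
    return out


def remap_simultaneous(mask: Mask, remap: Dict[int, int]) -> Mask:
    """Look each pixel up once in the original mask, so pairs never chain."""
    return [[remap.get(v, v) for v in row] for row in mask]


mask = [[0, 1], [2, 1]]
remap = {1: 2, 2: 3}
print(remap_sequential(mask, remap))    # [[0, 3], [3, 3]]
print(remap_simultaneous(mask, remap))  # [[0, 2], [3, 2]]
```

When keys and values are disjoint (e.g. `{255: 0}`), both functions give the same result.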
Hey all, another torchgeo user here. I hit this as well with 0.7.4. Here's a minimal reproducible example:
If you run the above on 0.7.4, you get the same error that @MathiasBaumgartinger reported. @adamjstewart, if you run this on 0.7.3, you can see that …
Bug description
Hi there! I have previously created my first `LightningDataModule`, more specifically a `NonGeoDataModule`, which inherits from it (see torchgeo-fork). Interestingly, when I try to run this module, I get:

`RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same or input should be a MKLDNN tensor and weight is a dense tensor`

Even more interesting is the fact that if I override `transfer_batch_to_device` like: … I get the output: …

It happens during the validation step (`lightning/pytorch/strategies/strategy.py`, line 411).
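To see which entries of a batch are still on the CPU at a given step (the kind of check the `transfer_batch_to_device` override in the report was used for), a small helper can walk the batch and report each tensor's device. This is a stdlib-only sketch assuming tensor-like leaves expose a `device` attribute, as `torch.Tensor` does; `collect_devices` is a hypothetical name, and `SimpleNamespace` objects stand in for tensors in the demo.

```python
from collections.abc import Mapping
from types import SimpleNamespace
from typing import Any, Dict


def collect_devices(batch: Any, path: str = "batch") -> Dict[str, str]:
    """Return {path: str(device)} for every leaf of ``batch`` that has a ``device``."""
    if hasattr(batch, "device"):
        return {path: str(batch.device)}
    found: Dict[str, str] = {}
    if isinstance(batch, Mapping):
        for key, value in batch.items():
            found.update(collect_devices(value, f"{path}[{key!r}]"))
    elif isinstance(batch, (list, tuple)):
        for index, value in enumerate(batch):
            found.update(collect_devices(value, f"{path}[{index}]"))
    return found  # non-tensor leaves (e.g. image ids) are skipped


# SimpleNamespace stands in for tensors here; real torch tensors expose .device too.
batch = {
    "image": SimpleNamespace(device="cuda:0"),
    "mask": SimpleNamespace(device="cpu"),
    "image_id": "tile_001",
}
devices = collect_devices(batch)
still_on_cpu = [p for p, d in devices.items() if d == "cpu"]
print(still_on_cpu)  # ["batch['mask']"]
```

Calling `print(collect_devices(batch))` at the start of `validation_step`, or right after the `super()` call in an overridden `transfer_batch_to_device`, would show exactly which entries were not moved.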
What version are you seeing the problem on?
v2.4
How to reproduce the bug
Error messages and logs
Environment
Current environment
More info
No response