diff --git a/s2s2net/model.py b/s2s2net/model.py index 0768e30..7df653e 100644 --- a/s2s2net/model.py +++ b/s2s2net/model.py @@ -18,6 +18,7 @@ import mmseg.models import numpy as np import pytorch_lightning as pl +import pytorch_lightning.utilities.deepspeed import rasterio import rioxarray import skimage.exposure @@ -173,7 +174,7 @@ def forward(self, x: torch.Tensor) -> typing.Dict[str, torch.Tensor]: } def evaluate( - self, batch: typing.Dict[str, torch.Tensor] + self, batch: typing.Dict[str, torch.Tensor], calc_loss: bool = True ) -> typing.Dict[str, torch.Tensor]: """ Compute the loss for a single batch in the training or validation step. @@ -199,83 +200,98 @@ def evaluate( dtype = torch.float16 if self.precision == 16 else torch.float x: torch.Tensor = batch["image"].to(dtype=dtype) # Input Sentinel-2 image y: torch.Tensor = batch["mask"] # Groundtruth binary mask - y_highres: torch.Tensor = batch["hres"] # High resolution image + if calc_loss: + y_highres: torch.Tensor = batch["hres"] # High resolution image # y = torch.randn(8, 1, 2560, 2560) # y_highres = torch.randn(8, 4, 2560, 2560) y_hat: typing.Dict[str, torch.Tensor] = self(x) ## Calculate loss values to minimize - def similarity_matrix(f): - # f expected shape (Bs, C', H', W') - # before computing the relationship of every pair of pixels, - # subsample the feature map to its 1/8 - f = F.interpolate( - f, size=(f.shape[2] // 8, f.shape[3] // 8), mode="nearest" + if calc_loss: # only on training and val step + + def similarity_matrix(f): + # f expected shape (Bs, C', H', W') + # before computing the relationship of every pair of pixels, + # subsample the feature map to its 1/8 + f = F.interpolate( + f, size=(f.shape[2] // 8, f.shape[3] // 8), mode="nearest" + ) + f = f.permute((0, 2, 3, 1)) + f = torch.reshape( + f, (f.shape[0], -1, f.shape[3]) + ) # shape (Bs, H'xW', C') + f_n = torch.linalg.norm(f, ord=None, dim=2).unsqueeze( + -1 + ) # ord=None indicates 2-Norm, + # unsqueeze last dimension to broadcast later + eps = 1e-8 + f_norm = f / torch.max(f_n, eps * torch.ones_like(f_n)) + sim_mt = f_norm @ f_norm.transpose(2, 1) + return sim_mt + + # 1: Feature Affinity loss calculation + _segmmask_sim_matrix: torch.Tensor = similarity_matrix( + f=y_hat["segmmask_conv_output_0"] + ) + _superres_sim_matrix: torch.Tensor = similarity_matrix( + f=y_hat["superres_conv_output_0"] + ) + _n_elements: int = ( + _segmmask_sim_matrix.shape[-2] * _segmmask_sim_matrix.shape[-1] + ) + _abs_dist: torch.Tensor = torch.abs( + _segmmask_sim_matrix - _superres_sim_matrix + ) + feature_affinity_loss: torch.Tensor = torch.mean( + (1 / _n_elements) * torch.sum(input=_abs_dist, dim=[-2, -1]) ) - f = f.permute((0, 2, 3, 1)) - f = torch.reshape(f, (f.shape[0], -1, f.shape[3])) # shape (Bs, H'xW', C') - f_n = torch.linalg.norm(f, ord=None, dim=2).unsqueeze( - -1 - ) # ord=None indicates 2-Norm, - # unsqueeze last dimension to broadcast later - eps = 1e-8 - f_norm = f / torch.max(f_n, eps * torch.ones_like(f_n)) - sim_mt = f_norm @ f_norm.transpose(2, 1) - return sim_mt - - # 1: Feature Affinity loss calculation - _segmmask_sim_matrix: torch.Tensor = similarity_matrix( - f=y_hat["segmmask_conv_output_0"] - ) - _superres_sim_matrix: torch.Tensor = similarity_matrix( - f=y_hat["superres_conv_output_0"] - ) - _n_elements: int = ( - _segmmask_sim_matrix.shape[-2] * _segmmask_sim_matrix.shape[-1] - ) - _abs_dist: torch.Tensor = torch.abs(_segmmask_sim_matrix - _superres_sim_matrix) - feature_affinity_loss: torch.Tensor = torch.mean( - (1 / _n_elements) * torch.sum(input=_abs_dist, dim=[-2, -1]) - ) - # 2: Semantic Segmentation loss (Focal Loss) - segmmask_loss: torch.Tensor = torchvision.ops.sigmoid_focal_loss( - inputs=y_hat["segmmask_conv_output_1"], - targets=y, - alpha=0.75, - gamma=2, - reduction="mean", - ) - # 3: Super-Resolution loss (Mean Absolute Error) - superres_loss: torch.Tensor = torchmetrics.functional.mean_absolute_error( - preds=y_hat["superres_conv_output_1"], - target=y_highres.to(dtype=torch.float16), - ) + # 2: Semantic Segmentation loss (Focal Loss) + segmmask_loss: torch.Tensor = torchvision.ops.sigmoid_focal_loss( + inputs=y_hat["segmmask_conv_output_1"], + targets=y, + alpha=0.75, + gamma=2, + reduction="mean", + ) + # 3: Super-Resolution loss (Mean Absolute Error) + superres_loss: torch.Tensor = torchmetrics.functional.mean_absolute_error( + preds=y_hat["superres_conv_output_1"], + target=y_highres.to(dtype=torch.float16), + ) - # 1 + 2 + 3: Calculate total loss and log to console - total_loss: torch.Tensor = ( - (1.0 * feature_affinity_loss) + segmmask_loss + (0.001 * superres_loss) - ) - losses: typing.Dict[str, torch.Tensor] = { - # Component losses (Feature Affinity, Segmentation, Super-Resolution) - "loss_feataffy": feature_affinity_loss.detach(), - "loss_segmmask": segmmask_loss.detach(), - "loss_superres": superres_loss.detach(), - } + # 1 + 2 + 3: Calculate total loss and log to console + total_loss: torch.Tensor = ( + (1.0 * feature_affinity_loss) + segmmask_loss + (0.001 * superres_loss) + ) + losses: typing.Dict[str, torch.Tensor] = { + # Total loss + "loss": total_loss, + # Component losses (Feature Affinity, Segmentation, Super-Resolution) + "loss_feataffy": feature_affinity_loss.detach(), + "loss_segmmask": segmmask_loss.detach(), + "loss_superres": superres_loss.detach(), + } + else: # if calc_loss is False, i.e. only on test step + losses: dict = {} # Calculate metrics to determine how good results are + preds = y_hat["segmmask_conv_output_1"] + target = (y > 0.5).to(dtype=torch.int8) # binarize + if preds.shape != target.shape: # resize prediction to target shape + preds = F.interpolate(input=preds, size=target.shape[-2:], mode="bilinear") + # print(x.shape, preds.shape, target.shape) + iou_score: torch.Tensor = self.iou( # Intersection over Union - preds=y_hat["segmmask_conv_output_1"].squeeze(), - target=(y > 0.5).squeeze().to(dtype=torch.int8), # binarize + preds=preds.squeeze(), target=target.squeeze() ) f1_score: torch.Tensor = self.f1_score( # F1 Score - preds=y_hat["segmmask_conv_output_1"].ravel(), - target=(y > 0.5).ravel().to(dtype=torch.int8), # binarize + preds=preds.ravel(), target=target.ravel() ) metrics: typing.Dict[str, torch.Tensor] = {"iou": iou_score, "f1": f1_score} - return {"loss": total_loss, **losses, **metrics} + return {**losses, **metrics} def training_step( self, batch: typing.Dict[str, torch.Tensor], batch_idx: int @@ -344,7 +360,7 @@ def predict_step( y_hat: typing.Dict[str, torch.Tensor] = self(x) segmmask: torch.Tensor = torch.sigmoid(input=y_hat["segmmask_conv_output_1"]) superres: torch.Tensor = y_hat["superres_conv_output_1"] - _, bands, height, width = superres.shape + _, bands, height, width = segmmask.shape try: # Coordintate Reference System of input image @@ -393,6 +409,31 @@ def predict_step( return results + def test_step( + self, batch: typing.Dict[str, torch.Tensor], batch_idx: int + ) -> torch.Tensor: + """ + Logic for the neural network's test loop. + """ + test_metrics: dict = self.evaluate(batch=batch, calc_loss=False) + + self.log_dict( + dictionary={f"test_{key}": value for key, value in test_metrics.items()}, + prog_bar=True, + ) + + # Log test metrics to Tensorboard + if self.logger is not None and hasattr(self.logger.experiment, "add_scalars"): + for metric_name, metric_value in test_metrics.items(): + self.logger.experiment.add_scalars( + main_tag=metric_name, + tag_scalar_dict={"test": metric_value}, + global_step=self.global_step, + # epoch=self.current_epoch, + ) + + return test_metrics["f1"] + def configure_optimizers(self): """ Optimizing function used to reduce the loss, so that the predicted @@ -433,6 +474,7 @@ def __init__( [typing.Dict[str, torch.Tensor]], typing.Dict[str, torch.Tensor] ] ] = None, + ids: typing.Optional[typing.List[str]] = None, ): self.root: str = root self.train: bool = train @@ -441,7 +483,7 @@ def __init__( img_path: str = ( os.path.join(self.root, "image") if self.train else os.path.join(self.root) ) - self.ids: list = [int(id) for id, _ in enumerate(os.listdir(path=img_path))] + self.ids: list[str] = ids or [path for path in os.listdir(path=img_path)] def __getitem__( self, index: int = 0 @@ -474,21 +516,57 @@ def __getitem__( sample: dict = {"image": image, "mask": mask, "hres": hres} - else: - filename: str = glob.glob( - os.path.join(self.root, f"{index:04d}", "S2*.tif") - )[0] - with rioxarray.open_rasterio(filename=filename) as rds: - assert rds.ndim == 3 # Channel, Height, Width - assert rds.shape[0] == 6 # 6 bands/channels (RGB+NIR+SWIR) - left, bottom, right, top = rds.rio.bounds() - sample: dict = { - "image": torch.as_tensor(data=rds.data.astype(np.int16)), - "crs": rds.rio.crs, - "bbox": rasterio.coords.BoundingBox( - left=left, right=right, bottom=bottom, top=top - ), - } + else: # if self.train is False, i.e. for predict and test dataloader + idx: str = self.ids[index] # e.g. 0123 + image_filename: str = glob.glob(os.path.join(self.root, idx, "S2*.tif"))[0] + with rioxarray.open_rasterio(filename=image_filename) as rds_image: + assert rds_image.ndim == 3 # Channel, Height, Width + assert rds_image.shape[0] == 6 # 6 bands/channels (RGB+NIR+SWIR) + image = rds_image + sample: dict = {"crs": image.rio.crs} + + # For test dataloader, also need to get mask to compute metrics + try: + mask_filename: str = glob.glob( + os.path.join(self.root, idx, "*_mask_*.tif") + )[0] + with rioxarray.open_rasterio(filename=mask_filename) as rds_mask: + assert rds_mask.ndim == 3 # Channel, Height, Width + assert rds_mask.shape[0] == 1 # 1 band/channel + + # Clip to bounding box extent of mask with non-NaN values + # Need to use low-res (10m) extent instead of 2m extent + mask_extent = ( + rds_mask.rio.reproject( + dst_crs=image.rio.crs, resolution=image.rio.resolution() + ) + .where( + cond=~rds_mask.isnull(), # keep non-NaN areas + # cond=rds_mask == 1,# keep with-valid-pixel areas + drop=True, + ) + .rio.bounds() + ) + sample["mask"] = torch.as_tensor( + data=rds_mask.rio.clip_box(*mask_extent).data # float32 + ) + + # Clip image to match geographical extent of binary mask + assert rds_mask.rio.crs == rds_image.rio.crs + image = image.rio.clip_box(*mask_extent) + + except IndexError: # if no mask in directory, don't add to sample + pass + + left, bottom, right, top = image.rio.bounds() + sample["bbox"] = rasterio.coords.BoundingBox( + left=left, right=right, bottom=bottom, top=top + ) + sample["image"] = torch.as_tensor( + data=image.data.astype(np.int16) # uint16 to int16 + ) + # assert sample["mask"].shape[1] == sample["image"].shape[1] * 5 + # assert sample["mask"].shape[2] == sample["image"].shape[2] * 5 if self.transforms is not None: sample: typing.Dict[str, torch.Tensor] = self.transforms(sample) @@ -548,6 +626,17 @@ def setup(self, stage: typing.Optional[str] = None) -> torch.utils.data.Dataset: self.dataset: torch.utils.data.Dataset = S2S2Dataset( root="SuperResolution/aligned", train=False ) + elif stage == "test": # Inference on test images + self.dataset: torch.utils.data.Dataset = S2S2Dataset( + root="SuperResolution/aligned", + train=False, + ids=["0123", "0124", "0125", "0126", "0211", "0223", "0157", "0439"], + ) + else: + raise ValueError( + f"Unknown stage: {stage}, " + "should be either 'fit', 'predict' or 'test'" + ) return self.dataset @@ -565,7 +654,6 @@ def val_dataloader(self) -> torch.utils.data.DataLoader: Loads the data used in the validation loop. Set the validation batch size here too. """ - # TODO use an independent validation set from different geographic region return torch.utils.data.DataLoader( dataset=self.dataset_val, batch_size=32, num_workers=4 ) @@ -582,11 +670,28 @@ def predict_dataloader(self) -> torch.utils.data.DataLoader: return torch.utils.data.DataLoader( dataset=self.dataset, batch_size=1, - num_workers=1, + num_workers=4, collate_fn=torchgeo.datasets.stack_samples, ) - # for batch in torch.utils.data.DataLoader(dataset=self.dataset, batch_size=1): + def test_dataloader(self) -> torch.utils.data.DataLoader: + """ + Loads the data used in the test loop. + Set the test batch size here too. + """ + return torch.utils.data.DataLoader( + dataset=self.dataset, + batch_size=1, + num_workers=4, + collate_fn=torchgeo.datasets.stack_samples, + ) + + # for batch in torch.utils.data.DataLoader( + # dataset=self.dataset, + # batch_size=1, + # num_workers=1, + # collate_fn=torchgeo.datasets.stack_samples, + # ): # break @@ -619,6 +724,10 @@ def cli_main(): tensorboard_logger: pl.loggers.LightningLoggerBase = pl.loggers.TensorBoardLogger( save_dir="tb_logs", name="s2s2net" ) + # Setup automatic checkpointing of best model + checkpoint_callback = pl.callbacks.ModelCheckpoint( + filename="{epoch}-{val_f1:.2f}-{step}", monitor="val_f1", mode="max" + ) # Training # TODO contribute to pytorch lightning so that deterministic="warn" works @@ -628,17 +737,29 @@ def cli_main(): trainer: pl.Trainer = pl.Trainer( # deterministic=True, accelerator="auto", + callbacks=[checkpoint_callback], devices="auto", strategy="deepspeed_stage_2", logger=tensorboard_logger, - max_epochs=27, + max_epochs=52, precision=16, ) - trainer.fit(model=model, datamodule=datamodule) + # Testing + if trainer.num_devices > 1: + trainer.test(model=model, datamodule=datamodule) + # Export Model - trainer.save_checkpoint(filepath="s2s2net.ckpt") + # Convert deepspeed checkpoint directory to single checkpoint file + # https://pytorch-lightning.readthedocs.io/en/1.6.3/advanced/model_parallel.html#deepspeed-zero-stage-3-single-file + # trainer.save_checkpoint(filepath="s2s2net_ckpt") + if checkpoint_callback.best_model_path: + print(f"Saving {checkpoint_callback.best_model_path} to s2s2net.ckpt") + pl.utilities.deepspeed.convert_zero_checkpoint_to_fp32_state_dict( + checkpoint_dir=checkpoint_callback.best_model_path, + output_file="s2s2net.ckpt", + ) print("Done!") diff --git a/tests/test_model.py b/tests/test_model.py index 05b4458..81868be 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -9,7 +9,7 @@ import s2s2net.model - +# %% class RandomDataset(torch.utils.data.Dataset): def __init__(self): pass @@ -25,6 +25,7 @@ def __getitem__(self, idx: int) -> dict: } +# %% def test_s2s2net(): """ Run a full train, val, test and prediction loop using 1 batch. @@ -36,11 +37,17 @@ def test_s2s2net(): model: pl.LightningModule = s2s2net.model.S2S2Net() # Training - trainer: pl.Trainer = pl.Trainer(accelerator="auto", fast_dev_run=True) + trainer: pl.Trainer = pl.Trainer(accelerator="auto", devices=1, fast_dev_run=True) trainer.fit(model=model, train_dataloaders=dataloader) - # Test inference - predictions = trainer.predict(model=model, dataloaders=dataloader) + # Inference/Prediction + predictions: list = trainer.predict(model=model, dataloaders=dataloader) segmmask, superres = predictions[0] assert segmmask.shape == (1, 1, 2560, 2560) assert superres.shape == (1, 4, 2560, 2560) + + # Test/Evaluation + scores: list[dict] = trainer.test(model=model, dataloaders=dataloader) + assert len(scores[0].keys()) == 2 + assert scores[0]["test_f1"] >= 0.0 + assert scores[0]["test_iou"] > 0