
Commit

Update vision mamba backbone to latest updates (#401)
anwai98 authored Nov 13, 2024
1 parent 1c251af commit 00c7218
Showing 8 changed files with 212 additions and 212 deletions.
70 changes: 23 additions & 47 deletions experiments/vision-mamba/vimunet/run_cremi.py
@@ -12,7 +12,6 @@
 import torch_em
 from torch_em.loss import DiceLoss
 from torch_em.util import segmentation
-from torch_em.data import MinInstanceSampler
 from torch_em.model import get_vimunet_model
 from torch_em.data.datasets import get_cremi_loader
 from torch_em.util.prediction import predict_with_halo
@@ -28,70 +27,49 @@
 CREMI_TEST_ROOT = "/scratch/projects/nim00007/sam/data/cremi/slices_original"


-def get_loaders(args, patch_shape=(1, 512, 512)):
+def get_loaders(input):
     train_rois = {"A": np.s_[0:75, :, :], "B": np.s_[0:75, :, :], "C": np.s_[0:75, :, :]}
     val_rois = {"A": np.s_[75:100, :, :], "B": np.s_[75:100, :, :], "C": np.s_[75:100, :, :]}

-    sampler = MinInstanceSampler()
-
-    train_loader = get_cremi_loader(
-        path=args.input,
-        patch_shape=patch_shape,
-        batch_size=2,
-        rois=train_rois,
-        sampler=sampler,
-        ndim=2,
-        label_dtype=torch.float32,
-        defect_augmentation_kwargs=None,
-        boundaries=True,
-        num_workers=16,
-        download=True,
-    )
-    val_loader = get_cremi_loader(
-        path=args.input,
-        patch_shape=patch_shape,
-        batch_size=1,
-        rois=val_rois,
-        sampler=sampler,
-        ndim=2,
-        label_dtype=torch.float32,
-        defect_augmentation_kwargs=None,
-        boundaries=True,
-        num_workers=16,
-        download=True,
-    )
+    kwargs = {
+        "path": input,
+        "patch_shape": (1, 512, 512),
+        "ndim": 2,
+        "label_dtype": torch.float32,
+        "defect_augmentation_kwargs": None,
+        "boundaries": True,
+        "num_workers": 16,
+        "download": True,
+        "shuffle": True,
+    }

+    train_loader = get_cremi_loader(batch_size=2, rois=train_rois, **kwargs)
+    val_loader = get_cremi_loader(batch_size=1, rois=val_rois, **kwargs)
     return train_loader, val_loader


 def run_cremi_training(args):
     # the dataloaders for cremi dataset
-    train_loader, val_loader = get_loaders(args)
+    train_loader, val_loader = get_loaders(input=args.input)

     # the vision-mamba + decoder (UNet-based) model
-    model = get_vimunet_model(
-        out_channels=1,
-        model_type=args.model_type,
-        with_cls_token=True
-    )
-
+    model = get_vimunet_model(out_channels=1, model_type=args.model_type, with_cls_token=True)
     save_root = os.path.join(args.save_root, "scratch", "boundaries", args.model_type)

-    # loss function
-    loss = DiceLoss()
-
     # trainer for the segmentation task
     trainer = torch_em.default_segmentation_trainer(
         name="cremi-vimunet",
         model=model,
         train_loader=train_loader,
         val_loader=val_loader,
         learning_rate=1e-4,
-        loss=loss,
-        metric=loss,
+        loss=DiceLoss(),
+        metric=DiceLoss(),
         log_image_interval=50,
         save_root=save_root,
         compile_model=False,
-        scheduler_kwargs={"mode": "min", "factor": 0.9, "patience": 10}
+        scheduler_kwargs={"mode": "min", "factor": 0.9, "patience": 10},
+        mixed_precision=False,
     )
     trainer.fit(iterations=int(1e5))

@@ -102,10 +80,7 @@ def run_cremi_inference(args, device):

     # the vision-mamba + decoder (UNet-based) model
     model = get_vimunet_model(
-        out_channels=1,
-        model_type=args.model_type,
-        with_cls_token=True,
-        checkpoint=checkpoint
+        out_channels=1, model_type=args.model_type, with_cls_token=True, checkpoint=checkpoint
     )

     all_test_images = glob(os.path.join(CREMI_TEST_ROOT, "raw", "cremi_test_*.tif"))
@@ -134,6 +109,7 @@ def run_cremi_inference(args, device):
         "SA50": np.mean(sa50_list),
         "SA75": np.mean(sa75_list)
     }
+
     res_path = os.path.join(args.result_path, "results.csv")
     df = pd.DataFrame.from_dict([res])
     df.to_csv(res_path)
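The substantive refactor in run_cremi.py collects the arguments shared by the train and validation loaders into a single dict that is unpacked with **kwargs. A minimal, self-contained sketch of that pattern (make_loader is a stand-in here, not the torch_em API):

def make_loader(split, batch_size, **kwargs):
    # Stand-in for get_cremi_loader: merely records its configuration.
    return {"split": split, "batch_size": batch_size, **kwargs}

# Shared settings are declared once ...
shared = {"patch_shape": (1, 512, 512), "num_workers": 16, "download": True, "shuffle": True}

# ... and only the per-loader differences are spelled out.
train_loader = make_loader("train", batch_size=2, **shared)
val_loader = make_loader("val", batch_size=1, **shared)
assert train_loader["num_workers"] == val_loader["num_workers"] == 16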
104 changes: 37 additions & 67 deletions experiments/vision-mamba/vimunet/run_livecell.py
@@ -22,93 +22,66 @@
 ROOT = "/scratch/usr/nimanwai"


-def get_loaders(args, patch_shape=(512, 512)):
-    if args.distances:
+def get_loaders(input, boundaries, distances):
+    label_trafo = None
+    if distances:
         label_trafo = torch_em.transform.label.PerObjectDistanceTransform(
-            distances=True,
-            boundary_distances=True,
-            directed_distances=False,
-            foreground=True,
-            min_size=25
+            distances=True, boundary_distances=True, directed_distances=False, foreground=True, min_size=25,
         )
-    else:
-        label_trafo = None

-    train_loader = get_livecell_loader(
-        path=args.input,
-        split="train",
-        patch_shape=patch_shape,
-        batch_size=2,
-        label_dtype=torch.float32,
-        boundaries=args.boundaries,
-        label_transform=label_trafo,
-        num_workers=16,
-        download=True,
-    )
-    val_loader = get_livecell_loader(
-        path=args.input,
-        split="val",
-        patch_shape=patch_shape,
-        batch_size=1,
-        label_dtype=torch.float32,
-        boundaries=args.boundaries,
-        label_transform=label_trafo,
-        num_workers=16,
-        download=True,
-    )
-
+    kwargs = {
+        "path": input,
+        "patch_shape": (512, 512),
+        "label_dtype": torch.float32,
+        "boundaries": boundaries,
+        "label_transform": label_trafo,
+        "num_workers": 16,
+        "download": True,
+        "shuffle": True,
+    }
+
+    train_loader = get_livecell_loader(split="train", batch_size=2, **kwargs)
+    val_loader = get_livecell_loader(split="val", batch_size=1, **kwargs)
     return train_loader, val_loader


-def get_output_channels(args):
-    if args.boundaries:
+def get_output_channels(boundaries):
+    if boundaries:
         output_channels = 2
     else:
         output_channels = 3

     return output_channels


-def get_loss_function(args):
-    if args.distances:
+def get_loss_function(distances):
+    if distances:
         loss = DiceBasedDistanceLoss(mask_distances_in_bg=True)
-
     else:
         loss = DiceLoss()

     return loss


-def get_save_root(args):
+def get_save_root(boundaries, model_type, save_root):
     # experiment_type
-    if args.boundaries:
+    if boundaries:
         experiment_type = "boundaries"
     else:
         experiment_type = "distances"

-    model_name = args.model_type
-
     # saving the model checkpoints
-    save_root = os.path.join(args.save_root, "scratch", experiment_type, model_name)
+    save_root = os.path.join(save_root, "scratch", experiment_type, model_type)
     return save_root


 def run_livecell_training(args):
     # the dataloaders for livecell dataset
-    train_loader, val_loader = get_loaders(args)
-
-    output_channels = get_output_channels(args)
+    train_loader, val_loader = get_loaders(input=args.input, boundaries=args.boundaries, distances=args.distances)
+    output_channels = get_output_channels(boundaries=args.boundaries)
+    loss = get_loss_function(distances=args.distances)
+    save_root = get_save_root(boundaries=args.boundaries, model_type=args.model_type, save_root=args.save_root)

     # the vision-mamba + decoder (UNet-based) model
-    model = get_vimunet_model(
-        out_channels=output_channels,
-        model_type=args.model_type,
-        with_cls_token=True,
-    )
-
-    save_root = get_save_root(args)
-
-    # loss function
-    loss = get_loss_function(args)
+    model = get_vimunet_model(out_channels=output_channels, model_type=args.model_type, with_cls_token=True)

     # trainer for the segmentation task
     trainer = torch_em.default_segmentation_trainer(
@@ -122,24 +95,20 @@ def run_livecell_training(args):
         log_image_interval=50,
         save_root=save_root,
         compile_model=False,
-        scheduler_kwargs={"mode": "min", "factor": 0.9, "patience": 10}
+        scheduler_kwargs={"mode": "min", "factor": 0.9, "patience": 10},
+        mixed_precision=False,
     )
     trainer.fit(iterations=int(1e5))


 def run_livecell_inference(args, device):
-    output_channels = get_output_channels(args)
-
-    save_root = get_save_root(args)
-
+    output_channels = get_output_channels(boundaries=args.boundaries)
+    save_root = get_save_root(boundaries=args.boundaries, model_type=args.model_type, save_root=args.save_root)
     checkpoint = os.path.join(save_root, "checkpoints", "livecell-vimunet", "best.pt")

     # the vision-mamba + decoder (UNet-based) model
     model = get_vimunet_model(
-        out_channels=output_channels,
-        model_type=args.model_type,
-        with_cls_token=True,
-        checkpoint=checkpoint,
+        out_channels=output_channels, model_type=args.model_type, with_cls_token=True, checkpoint=checkpoint,
     )

     # the splits are provided with the livecell dataset
@@ -149,7 +118,7 @@ def run_livecell_inference(args, device):
     all_test_labels = glob(os.path.join(ROOT, "data", "livecell", "annotations", "livecell_test_images", "*", "*"))

     msa_list, sa50_list, sa75_list = [], [], []
-    for label_path in tqdm(all_test_labels):
+    for label_path in tqdm(all_test_labels, desc="Prediction for LIVECell"):
         labels = imageio.imread(label_path)
         image_id = os.path.split(label_path)[-1]

@@ -184,6 +153,7 @@ def run_livecell_inference(args, device):
         "SA50": np.mean(sa50_list),
         "SA75": np.mean(sa75_list)
     }
+
     res_path = os.path.join(args.result_path, "results.csv")
     df = pd.DataFrame.from_dict([res])
     df.to_csv(res_path)
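run_livecell.py follows the same theme: each helper now takes the specific values it needs (input, boundaries, distances, ...) instead of the whole argparse namespace, so the helpers can be called without constructing an args object. A small sketch of the channel logic, with the branch taken from the diff above and the calls purely illustrative:

def get_output_channels(boundaries):
    # 2 channels for the boundary task (foreground, boundaries); 3 for the
    # distance task, presumably foreground plus the two distance maps given
    # the PerObjectDistanceTransform settings above.
    return 2 if boundaries else 3

print(get_output_channels(boundaries=True))   # -> 2
print(get_output_channels(boundaries=False))  # -> 3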
2 changes: 1 addition & 1 deletion torch_em/data/datasets/medical/cbis_ddsm.py
@@ -45,7 +45,7 @@ def get_cbis_ddsm_paths(
     path: Union[os.PathLike, str],
     split: Literal['Train', 'Val', 'Test'],
     task: Literal['Calc', 'Mass'],
-    tumour_type: Literal['MALIGNANT', 'BENIGN'],
+    tumour_type: Optional[Literal["MALIGNANT", "BENIGN"]] = None,
     download: bool = False
 ):
     """Get paths to the CBIS DDSM data.
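The single-line change to cbis_ddsm.py relaxes the signature: tumour_type becomes optional and defaults to None. A hedged usage sketch — the import path mirrors the file path, the data root is hypothetical, and the reading that None selects both tumour types is an assumption, not something this diff shows:

from torch_em.data.datasets.medical.cbis_ddsm import get_cbis_ddsm_paths

paths = get_cbis_ddsm_paths(
    path="./data/cbis_ddsm",  # hypothetical local data root
    split="Train",
    task="Mass",
    # tumour_type is no longer required; omitting it leaves the default None
    download=False,
)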
8 changes: 4 additions & 4 deletions torch_em/data/datasets/medical/piccolo.py
@@ -40,15 +40,15 @@ def get_piccolo_data(path: Union[os.PathLike, str], download: bool = False) -> s
     Returns:
         Filepath where the data is downloaded.
     """
+    data_dir = os.path.join(path, r"piccolo dataset-release0.1")
+    if os.path.exists(data_dir):
+        return data_dir
+
     if download:
         raise NotImplementedError(
             "Automatic download is not possible for this dataset. See 'get_piccolo_data' for details."
         )

-    data_dir = os.path.join(path, r"piccolo dataset-release0.1")
-    if os.path.exists(data_dir):
-        return data_dir
-
     rar_file = os.path.join(path, r"piccolo dataset_widefield-release0.1.rar")
     if not os.path.exists(rar_file):
         raise FileNotFoundError(
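The piccolo.py edit is a pure control-flow reorder: the check for already-extracted data now runs before the download guard, so calling the function again with download=True returns the existing data instead of raising. A generic sketch of the corrected ordering (get_data and the messages are illustrative):

import os

def get_data(path, download=False):
    data_dir = os.path.join(path, "piccolo dataset-release0.1")
    if os.path.exists(data_dir):
        # An existing extraction is honoured before the download guard can raise.
        return data_dir
    if download:
        raise NotImplementedError("Automatic download is not possible for this dataset.")
    raise FileNotFoundError("Please obtain and extract the data manually.")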
20 changes: 6 additions & 14 deletions torch_em/model/unetr.py
@@ -5,8 +5,8 @@
 import torch.nn as nn
 import torch.nn.functional as F

-from .unet import Decoder, ConvBlock2d, Upsampler2d
 from .vit import get_vision_transformer
+from .unet import Decoder, ConvBlock2d, Upsampler2d

 try:
     from micro_sam.util import get_sam_model
@@ -22,18 +22,15 @@
 class UNETR(nn.Module):

     def _load_encoder_from_checkpoint(self, backbone, encoder, checkpoint):
-
+        """Function to load pretrained weights to the image encoder.
+        """
         if isinstance(checkpoint, str):
             if backbone == "sam" and isinstance(encoder, str):
                 # If we have a SAM encoder, then we first try to load the full SAM Model
                 # (using micro_sam) and otherwise fall back on directly loading the encoder state
                 # from the checkpoint
                 try:
-                    _, model = get_sam_model(
-                        model_type=encoder,
-                        checkpoint_path=checkpoint,
-                        return_sam=True
-                    )
+                    _, model = get_sam_model(model_type=encoder, checkpoint_path=checkpoint, return_sam=True)
                     encoder_state = model.image_encoder.state_dict()
                 except Exception:
                     # Try loading the encoder state directly from a checkpoint.
@@ -47,8 +44,7 @@ def _load_encoder_from_checkpoint(self, backbone, encoder, checkpoint):
                 k: v for k, v in encoder_state.items()
                 if (k != "mask_token" and not k.startswith("decoder"))
             })
-
-            # let's remove the `head` from our current encoder (as the MAE pretrained don't expect it)
+            # Let's remove the `head` from our current encoder (as the MAE pretrained don't expect it)
             current_encoder_state = self.encoder.state_dict()
             if ("head.weight" in current_encoder_state) and ("head.bias" in current_encoder_state):
                 del self.encoder.head
@@ -72,7 +68,7 @@ def __init__(
         final_activation: Optional[Union[str, nn.Module]] = None,
         use_skip_connection: bool = True,
         embed_dim: Optional[int] = None,
-        use_conv_transpose=True,
+        use_conv_transpose: bool = True,
     ) -> None:
         super().__init__()

@@ -150,15 +146,11 @@ def __init__(
         self.deconv4 = Deconv2DBlock(features_decoder[2], features_decoder[3])

         self.base = ConvBlock2d(embed_dim, features_decoder[0])
-
         self.out_conv = nn.Conv2d(features_decoder[-1], out_channels, 1)
-
         self.deconv_out = _upsampler(
             scale_factor=2, in_channels=features_decoder[-1], out_channels=features_decoder[-1]
         )
-
         self.decoder_head = ConvBlock2d(2 * features_decoder[-1], features_decoder[-1])
-
         self.final_activation = self._get_activation(final_activation)

     def _get_activation(self, activation):
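The unetr.py changes are cosmetic (import order, formatting, a docstring) apart from typing the use_conv_transpose default. For orientation, a hedged instantiation sketch — final_activation, use_skip_connection and use_conv_transpose appear in this diff, while backbone, encoder and out_channels are assumptions about the rest of the constructor:

from torch_em.model import UNETR  # import path assumed

model = UNETR(
    backbone="sam",
    encoder="vit_b",
    out_channels=1,
    final_activation="Sigmoid",
    use_skip_connection=True,
    use_conv_transpose=True,  # now explicitly annotated as bool
)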
