Skip to content

Commit

Permalink
Move metatensor support into dev branch (#4562)
Browse files Browse the repository at this point in the history
  • Loading branch information
wyli authored Jul 5, 2022
1 parent a2d6346 commit 24cf761
Show file tree
Hide file tree
Showing 222 changed files with 4,660 additions and 3,450 deletions.
5 changes: 4 additions & 1 deletion monai/apps/deepgrow/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@

import numpy as np

from monai.transforms import AsChannelFirstd, Compose, LoadImaged, Orientationd, Spacingd
from monai.transforms import AsChannelFirstd, Compose, FromMetaTensord, LoadImaged, Orientationd, Spacingd, ToNumpyd
from monai.utils import GridSampleMode
from monai.utils.enums import PostFix


def create_dataset(
Expand Down Expand Up @@ -128,6 +129,8 @@ def _default_transforms(image_key, label_key, pixdim):
AsChannelFirstd(keys=keys),
Orientationd(keys=keys, axcodes="RAS"),
Spacingd(keys=keys, pixdim=pixdim, mode=mode),
FromMetaTensord(keys=keys),
ToNumpyd(keys=keys + [PostFix.meta(k) for k in keys]),
]
)

Expand Down
18 changes: 5 additions & 13 deletions monai/apps/detection/transforms/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,9 +205,7 @@ def __init__(self, zoom: Union[Sequence[float], float], keep_size: bool = False,
self.keep_size = keep_size
self.kwargs = kwargs

def __call__(
self, boxes: NdarrayOrTensor, src_spatial_size: Union[Sequence[int], int, None] = None
) -> NdarrayOrTensor: # type: ignore
def __call__(self, boxes: torch.Tensor, src_spatial_size: Union[Sequence[int], int, None] = None):
"""
Args:
boxes: source bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode``
Expand Down Expand Up @@ -266,9 +264,7 @@ def __init__(self, spatial_size: Union[Sequence[int], int], size_mode: str = "al
self.size_mode = look_up_option(size_mode, ["all", "longest"])
self.spatial_size = spatial_size

def __call__( # type: ignore
self, boxes: NdarrayOrTensor, src_spatial_size: Union[Sequence[int], int]
) -> NdarrayOrTensor:
def __call__(self, boxes: NdarrayOrTensor, src_spatial_size: Union[Sequence[int], int]): # type: ignore
"""
Args:
boxes: source bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode``
Expand Down Expand Up @@ -316,9 +312,7 @@ class FlipBox(Transform):
def __init__(self, spatial_axis: Optional[Union[Sequence[int], int]] = None) -> None:
self.spatial_axis = spatial_axis

def __call__( # type: ignore
self, boxes: NdarrayOrTensor, spatial_size: Union[Sequence[int], int]
) -> NdarrayOrTensor:
def __call__(self, boxes: NdarrayOrTensor, spatial_size: Union[Sequence[int], int]): # type: ignore
"""
Args:
boxes: bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode``
Expand Down Expand Up @@ -489,7 +483,7 @@ def __init__(

def __call__( # type: ignore
self, boxes: NdarrayOrTensor, labels: Union[Sequence[NdarrayOrTensor], NdarrayOrTensor]
) -> Tuple[NdarrayOrTensor, Union[Tuple, NdarrayOrTensor]]:
):
"""
Args:
boxes: bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode``
Expand Down Expand Up @@ -535,9 +529,7 @@ class RotateBox90(Rotate90):
def __init__(self, k: int = 1, spatial_axes: Tuple[int, int] = (0, 1)) -> None:
super().__init__(k, spatial_axes)

def __call__( # type: ignore
self, boxes: NdarrayOrTensor, spatial_size: Union[Sequence[int], int]
) -> NdarrayOrTensor:
def __call__(self, boxes: NdarrayOrTensor, spatial_size: Union[Sequence[int], int]): # type: ignore
"""
Args:
img: channel first array, must have shape: (num_channels, H[, W, ..., ]),
Expand Down
8 changes: 4 additions & 4 deletions monai/apps/detection/transforms/box_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def apply_affine_to_boxes(boxes: NdarrayOrTensor, affine: NdarrayOrTensor) -> Nd
return boxes_affine


def zoom_boxes(boxes: NdarrayOrTensor, zoom: Union[Sequence[float], float]) -> NdarrayOrTensor:
def zoom_boxes(boxes: NdarrayOrTensor, zoom: Union[Sequence[float], float]):
"""
Zoom boxes
Expand Down Expand Up @@ -128,7 +128,7 @@ def zoom_boxes(boxes: NdarrayOrTensor, zoom: Union[Sequence[float], float]) -> N

def resize_boxes(
boxes: NdarrayOrTensor, src_spatial_size: Union[Sequence[int], int], dst_spatial_size: Union[Sequence[int], int]
) -> NdarrayOrTensor:
):
"""
Resize boxes when the corresponding image is resized
Expand Down Expand Up @@ -262,7 +262,7 @@ def convert_box_to_mask(
boxes_only_mask = resizer(boxes_only_mask[None])[0] # type: ignore
else:
# generate a rect mask
boxes_only_mask = np.ones(box_size, dtype=np.int16) * np.int16(labels_np[b]) # type: ignore
boxes_only_mask = np.ones(box_size, dtype=np.int16) * np.int16(labels_np[b])
# apply to global mask
slicing = [b]
slicing.extend(slice(boxes_np[b, d], boxes_np[b, d + spatial_dims]) for d in range(spatial_dims)) # type:ignore
Expand Down Expand Up @@ -334,7 +334,7 @@ def select_labels(
Return:
selected labels, does not share memory with original labels.
"""
labels_tuple = ensure_tuple(labels, True) # type: ignore
labels_tuple = ensure_tuple(labels, True)

labels_select_list = []
keep_t: torch.Tensor = convert_data_type(keep, torch.Tensor)[0]
Expand Down
48 changes: 24 additions & 24 deletions monai/apps/detection/transforms/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
ZoomBox,
)
from monai.apps.detection.transforms.box_ops import convert_box_to_mask
from monai.config import KeysCollection
from monai.config import KeysCollection, SequenceStr
from monai.config.type_definitions import NdarrayOrTensor
from monai.data.box_utils import COMPUTE_DTYPE, BoxMode, clip_boxes_to_image
from monai.data.utils import orientation_ras_lps
Expand All @@ -43,7 +43,7 @@
from monai.transforms.transform import MapTransform, Randomizable, RandomizableTransform
from monai.transforms.utils import generate_pos_neg_label_crop_centers, map_binary_to_indices
from monai.utils import ImageMetaKey as Key
from monai.utils import InterpolateMode, NumpyPadMode, PytorchPadMode, ensure_tuple, ensure_tuple_rep
from monai.utils import InterpolateMode, NumpyPadMode, ensure_tuple, ensure_tuple_rep
from monai.utils.enums import PostFix, TraceKeys
from monai.utils.type_conversion import convert_data_type

Expand Down Expand Up @@ -90,8 +90,6 @@
]

DEFAULT_POST_FIX = PostFix.meta()
InterpolateModeSequence = Union[Sequence[Union[InterpolateMode, str]], InterpolateMode, str]
PadModeSequence = Union[Sequence[Union[NumpyPadMode, PytorchPadMode, str]], NumpyPadMode, PytorchPadMode, str]


class ConvertBoxModed(MapTransform, InvertibleTransform):
Expand Down Expand Up @@ -377,8 +375,8 @@ def __init__(
box_keys: KeysCollection,
box_ref_image_keys: KeysCollection,
zoom: Union[Sequence[float], float],
mode: InterpolateModeSequence = InterpolateMode.AREA,
padding_mode: PadModeSequence = NumpyPadMode.EDGE,
mode: SequenceStr = InterpolateMode.AREA,
padding_mode: SequenceStr = NumpyPadMode.EDGE,
align_corners: Union[Sequence[Optional[bool]], Optional[bool]] = None,
keep_size: bool = True,
allow_missing_keys: bool = False,
Expand All @@ -395,7 +393,7 @@ def __init__(
self.zoomer = Zoom(zoom=zoom, keep_size=keep_size, **kwargs)
self.keep_size = keep_size

def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]:
d = dict(data)

# zoom box
Expand Down Expand Up @@ -431,7 +429,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N

return d

def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]:
d = deepcopy(dict(data))

for key in self.key_iterator(d):
Expand All @@ -453,7 +451,8 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, Nd
align_corners=None if align_corners == TraceKeys.NONE else align_corners,
)
# Size might be out by 1 voxel so pad
d[key] = SpatialPad(transform[TraceKeys.EXTRA_INFO]["original_shape"], mode="edge")(d[key])
orig_shape = transform[TraceKeys.EXTRA_INFO]["original_shape"]
d[key] = SpatialPad(orig_shape, mode="edge")(d[key])

# zoom boxes
if key_type == "box_key":
Expand Down Expand Up @@ -518,8 +517,8 @@ def __init__(
prob: float = 0.1,
min_zoom: Union[Sequence[float], float] = 0.9,
max_zoom: Union[Sequence[float], float] = 1.1,
mode: InterpolateModeSequence = InterpolateMode.AREA,
padding_mode: PadModeSequence = NumpyPadMode.EDGE,
mode: SequenceStr = InterpolateMode.AREA,
padding_mode: SequenceStr = NumpyPadMode.EDGE,
align_corners: Union[Sequence[Optional[bool]], Optional[bool]] = None,
keep_size: bool = True,
allow_missing_keys: bool = False,
Expand All @@ -544,7 +543,7 @@ def set_random_state(
self.rand_zoom.set_random_state(seed, state)
return self

def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]:
d = dict(data)
first_key: Union[Hashable, List] = self.first_key(d)
if first_key == []:
Expand Down Expand Up @@ -594,7 +593,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N

return d

def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]:
d = deepcopy(dict(data))

for key in self.key_iterator(d):
Expand All @@ -616,7 +615,8 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, Nd
align_corners=None if align_corners == TraceKeys.NONE else align_corners,
)
# Size might be out by 1 voxel so pad
d[key] = SpatialPad(transform[TraceKeys.EXTRA_INFO]["original_shape"], mode="edge")(d[key])
orig_shape = transform[TraceKeys.EXTRA_INFO]["original_shape"]
d[key] = SpatialPad(orig_shape, mode="edge")(d[key])

# zoom boxes
if key_type == "box_key":
Expand Down Expand Up @@ -661,7 +661,7 @@ def __init__(
self.flipper = Flip(spatial_axis=spatial_axis)
self.box_flipper = FlipBox(spatial_axis=self.flipper.spatial_axis)

def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]:
d = dict(data)

for key in self.image_keys:
Expand All @@ -674,7 +674,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N
self.push_transform(d, box_key, extra_info={"spatial_size": spatial_size, "type": "box_key"})
return d

def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]:
d = deepcopy(dict(data))

for key in self.key_iterator(d):
Expand Down Expand Up @@ -735,7 +735,7 @@ def set_random_state(
self.flipper.set_random_state(seed, state)
return self

def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]:
d = dict(data)
self.randomize(None)

Expand All @@ -751,7 +751,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N
self.push_transform(d, box_key, extra_info={"spatial_size": spatial_size, "type": "box_key"})
return d

def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]:
d = deepcopy(dict(data))

for key in self.key_iterator(d):
Expand Down Expand Up @@ -1172,7 +1172,7 @@ def randomize( # type: ignore
self.allow_smaller,
)

def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> List[Dict[Hashable, NdarrayOrTensor]]:
def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> List[Dict[Hashable, torch.Tensor]]:
d = dict(data)
spatial_dims = len(d[self.image_keys[0]].shape) - 1
image_size = d[self.image_keys[0]].shape[1:]
Expand All @@ -1190,7 +1190,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> List[Dict[Hashab
raise ValueError("no available ROI centers to crop.")

# initialize returned list with shallow copy to preserve key ordering
results: List[Dict[Hashable, NdarrayOrTensor]] = [dict(d) for _ in range(self.num_samples)]
results: List[Dict[Hashable, torch.Tensor]] = [dict(d) for _ in range(self.num_samples)]

# crop images and boxes for each center.
for i, center in enumerate(self.centers):
Expand Down Expand Up @@ -1255,7 +1255,7 @@ def __init__(
self.img_rotator = Rotate90(k, spatial_axes)
self.box_rotator = RotateBox90(k, spatial_axes)

def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Mapping[Hashable, NdarrayOrTensor]:
def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Mapping[Hashable, torch.Tensor]:
d = dict(data)
for key, box_ref_image_key in zip(self.box_keys, self.box_ref_image_keys):
spatial_size = list(d[box_ref_image_key].shape[1:])
Expand All @@ -1273,7 +1273,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Mapping[Hashable
self.push_transform(d, key, extra_info={"type": "image_key"})
return d

def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]:
d = deepcopy(dict(data))

for key in self.key_iterator(d):
Expand Down Expand Up @@ -1327,7 +1327,7 @@ def __init__(
super().__init__(self.image_keys + self.box_keys, prob, max_k, spatial_axes, allow_missing_keys)
self.box_ref_image_keys = ensure_tuple_rep(box_ref_image_keys, len(self.box_keys))

def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Mapping[Hashable, NdarrayOrTensor]:
def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Mapping[Hashable, torch.Tensor]:
self.randomize()
d = dict(data)

Expand Down Expand Up @@ -1359,7 +1359,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Mapping[Hashable
self.push_transform(d, key, extra_info={"rand_k": self._rand_k, "type": "image_key"})
return d

def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]:
d = deepcopy(dict(data))
if self._rand_k % 4 == 0:
return d
Expand Down
4 changes: 2 additions & 2 deletions monai/apps/detection/utils/detector_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def pad_images(
if max(pt_pad_width) == 0:
# if there is no need to pad
return input_images, [orig_size] * input_images.shape[0]
mode_: str = convert_pad_mode(dst=input_images, mode=mode).value
mode_: str = convert_pad_mode(dst=input_images, mode=mode)
return F.pad(input_images, pt_pad_width, mode=mode_, **kwargs), [orig_size] * input_images.shape[0]

# If input_images: List[Tensor])
Expand All @@ -151,7 +151,7 @@ def pad_images(
# Use `SpatialPad` to match sizes, padding in the end will not affect boxes
padder = SpatialPad(spatial_size=max_spatial_size, method="end", mode=mode, **kwargs)
for idx, img in enumerate(input_images):
images[idx, ...] = padder(img) # type: ignore
images[idx, ...] = padder(img)

return images, [list(ss) for ss in image_sizes]

Expand Down
23 changes: 10 additions & 13 deletions monai/apps/nuclick/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,19 @@

import math
import random
from enum import Enum
from typing import Any, Tuple, Union

import numpy as np

from monai.config import KeysCollection
from monai.transforms import MapTransform, Randomizable, SpatialPad
from monai.utils import optional_import
from monai.utils import StrEnum, optional_import

measure, _ = optional_import("skimage.measure")
morphology, _ = optional_import("skimage.morphology")


class NuclickKeys(Enum):
class NuclickKeys(StrEnum):
"""
Keys for nuclick transforms.
"""
Expand Down Expand Up @@ -83,7 +82,7 @@ class ExtractPatchd(MapTransform):
def __init__(
self,
keys: KeysCollection,
centroid_key: str = NuclickKeys.CENTROID.value,
centroid_key: str = NuclickKeys.CENTROID,
patch_size: Union[Tuple[int, int], int] = 128,
allow_missing_keys: bool = False,
**kwargs: Any,
Expand Down Expand Up @@ -138,9 +137,9 @@ class SplitLabeld(MapTransform):
def __init__(
self,
keys: KeysCollection,
# label: str = NuclickKeys.LABEL.value,
others: str = NuclickKeys.OTHERS.value,
mask_value: str = NuclickKeys.MASK_VALUE.value,
# label: str = NuclickKeys.LABEL,
others: str = NuclickKeys.OTHERS,
mask_value: str = NuclickKeys.MASK_VALUE,
min_area: int = 5,
):

Expand Down Expand Up @@ -268,9 +267,9 @@ class AddPointGuidanceSignald(Randomizable, MapTransform):

def __init__(
self,
image: str = NuclickKeys.IMAGE.value,
label: str = NuclickKeys.LABEL.value,
others: str = NuclickKeys.OTHERS.value,
image: str = NuclickKeys.IMAGE,
label: str = NuclickKeys.LABEL,
others: str = NuclickKeys.OTHERS,
drop_rate: float = 0.5,
jitter_range: int = 3,
):
Expand Down Expand Up @@ -338,9 +337,7 @@ class AddClickSignalsd(MapTransform):
bb_size: single integer size, defines a bounding box like (bb_size, bb_size)
"""

def __init__(
self, image: str = NuclickKeys.IMAGE.value, foreground: str = NuclickKeys.FOREGROUND.value, bb_size: int = 128
):
def __init__(self, image: str = NuclickKeys.IMAGE, foreground: str = NuclickKeys.FOREGROUND, bb_size: int = 128):
self.image = image
self.foreground = foreground
self.bb_size = bb_size
Expand Down
Loading

0 comments on commit 24cf761

Please sign in to comment.