Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 23 additions & 3 deletions monai/inferers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,8 @@ def sliding_window_inference(

Args:
inputs: input image to be processed (assuming NCHW[D])
roi_size: the spatial window size for inferences.
roi_size: the spatial window size for inferences, this must be a single value or a tuple with values
for each spatial dimension (eg. 2 for 2D, 3 for 3D).
When its components have None or non-positives, the corresponding inputs dimension will be used.
if the components of the `roi_size` are non-positive values, the transform will use the
corresponding components of img size. For example, `roi_size=(32, -1)` will be adapted
Expand Down Expand Up @@ -131,11 +132,30 @@ def sliding_window_inference(
kwargs: optional keyword args to be passed to ``predictor``.

Note:
- input must be channel-first and have a batch dim, supports N-D sliding window.
- Inputs must be channel-first and have a batch dim (NCHW / NCDHW).
- If your data is NHWC/NDHWC, please apply `EnsureChannelFirst` / `EnsureChannelFirstd` upstream.

Raises:
ValueError: When the input dimensions do not match the expected dimensions based on ``roi_size``.

"""
buffered = buffer_steps is not None and buffer_steps > 0
num_spatial_dims = len(inputs.shape) - 2

# Only perform strict shape validation if roi_size is a sequence (explicit dimensions).
# If roi_size is an integer, it is broadcast to all dimensions, so we cannot
# infer the expected dimensionality to enforce a strict check here.
if isinstance(roi_size, Sequence):
roi_dims = len(roi_size)
if num_spatial_dims != roi_dims:
raise ValueError(
f"Inputs must have {roi_dims + 2} dimensions for {roi_dims}D roi_size "
f"(Batch, Channel, {', '.join(['Spatial'] * roi_dims)}), "
f"but got inputs shape {inputs.shape}.\n"
"If you have channel-last data (e.g. B, D, H, W, C), please use "
"monai.transforms.EnsureChannelFirst or EnsureChannelFirstd upstream."
)
# -----------------------------------------------------------------
buffered = buffer_steps is not None and buffer_steps > 0
if buffered:
if buffer_dim < -num_spatial_dims or buffer_dim > num_spatial_dims:
raise ValueError(f"buffer_dim must be in [{-num_spatial_dims}, {num_spatial_dims}], got {buffer_dim}.")
Expand Down
20 changes: 20 additions & 0 deletions tests/inferers/test_sliding_window_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,26 @@ def compute_dict(data):
for rr, _ in zip(result_dict, expected_dict):
np.testing.assert_allclose(result_dict[rr].cpu().numpy(), expected_dict[rr], rtol=1e-4)

def test_strict_shape_validation(self):
"""Test strict shape validation to ensure inputs match roi_size dimensions."""
device = "cpu"
roi_size = (16, 16, 16)
sw_batch_size = 4

def predictor(data):
return data

# Case 1: Input has fewer dimensions than expected (e.g., missing Batch or Channel)
# 3D roi_size requires 5D input (B, C, D, H, W), giving 4D here.
inputs_4d = torch.randn((1, 16, 16, 16), device=device)
with self.assertRaisesRegex(ValueError, "Inputs must have 5 dimensions"):
sliding_window_inference(inputs_4d, roi_size, sw_batch_size, predictor)

# Case 2: Input is 3D (missing Batch AND Channel)
inputs_3d = torch.randn((16, 16, 16), device=device)
with self.assertRaisesRegex(ValueError, "Inputs must have 5 dimensions"):
sliding_window_inference(inputs_3d, roi_size, sw_batch_size, predictor)


class TestSlidingWindowInferenceCond(unittest.TestCase):
@parameterized.expand(TEST_CASES)
Expand Down
Loading