Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions monai/inferers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,11 +131,27 @@ def sliding_window_inference(
kwargs: optional keyword args to be passed to ``predictor``.

Note:
- input must be channel-first and have a batch dim, supports N-D sliding window.
- Inputs must be channel-first and have a batch dim (NCHW / NCDHW).
- If your data is NHWC/NDHWC, please apply `EnsureChannelFirst` / `EnsureChannelFirstd` upstream.

"""
buffered = buffer_steps is not None and buffer_steps > 0
num_spatial_dims = len(inputs.shape) - 2

# Only perform strict shape validation if roi_size is a sequence (explicit dimensions).
# If roi_size is an integer, it is broadcast to all dimensions, so we cannot
# infer the expected dimensionality to enforce a strict check here.
if isinstance(roi_size, Sequence):
roi_dims = len(roi_size)
if num_spatial_dims != roi_dims:
raise ValueError(
f"inputs must have {roi_dims + 2} dimensions for {roi_dims}D roi_size "
f"(Batch, Channel, {', '.join(['Spatial'] * roi_dims)}), "
f"but got inputs shape {inputs.shape}.\n"
"If you have channel-last data (e.g. B, D, H, W, C), please use "
"monai.transforms.EnsureChannelFirst or EnsureChannelFirstd upstream."
)
# -----------------------------------------------------------------
buffered = buffer_steps is not None and buffer_steps > 0
if buffered:
if buffer_dim < -num_spatial_dims or buffer_dim > num_spatial_dims:
raise ValueError(f"buffer_dim must be in [{-num_spatial_dims}, {num_spatial_dims}], got {buffer_dim}.")
Expand Down
20 changes: 20 additions & 0 deletions tests/inferers/test_sliding_window_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,26 @@ def compute_dict(data):
for rr, _ in zip(result_dict, expected_dict):
np.testing.assert_allclose(result_dict[rr].cpu().numpy(), expected_dict[rr], rtol=1e-4)

def test_strict_shape_validation(self):
"""Test strict shape validation to ensure inputs match roi_size dimensions."""
device = "cpu"
roi_size = (16, 16, 16)
sw_batch_size = 4

def predictor(data):
return data

# Case 1: Input has fewer dimensions than expected (e.g., missing Batch or Channel)
# 3D roi_size requires 5D input (B, C, D, H, W), giving 4D here.
inputs_4d = torch.randn((1, 16, 16, 16), device=device)
with self.assertRaisesRegex(ValueError, "inputs must have 5 dimensions"):
sliding_window_inference(inputs_4d, roi_size, sw_batch_size, predictor)

# Case 2: Input is 3D (missing Batch AND Channel)
inputs_3d = torch.randn((16, 16, 16), device=device)
with self.assertRaisesRegex(ValueError, "inputs must have 5 dimensions"):
sliding_window_inference(inputs_3d, roi_size, sw_batch_size, predictor)


class TestSlidingWindowInferenceCond(unittest.TestCase):
@parameterized.expand(TEST_CASES)
Expand Down
Loading