diff --git a/monai/inferers/utils.py b/monai/inferers/utils.py
index 766486a807..3dfbc2032c 100644
--- a/monai/inferers/utils.py
+++ b/monai/inferers/utils.py
@@ -76,7 +76,8 @@ def sliding_window_inference(
 
     Args:
         inputs: input image to be processed (assuming NCHW[D])
-        roi_size: the spatial window size for inferences.
+        roi_size: the spatial window size for inferences. This must be a single value or a sequence with one
+            value per spatial dimension (e.g. length 2 for 2D, length 3 for 3D).
             When its components have None or non-positives, the corresponding inputs dimension will be used.
             if the components of the `roi_size` are non-positive values, the transform will use the
             corresponding components of img size. For example, `roi_size=(32, -1)` will be adapted
@@ -131,11 +132,30 @@ def sliding_window_inference(
         kwargs: optional keyword args to be passed to ``predictor``.
 
     Note:
-        - input must be channel-first and have a batch dim, supports N-D sliding window.
+        - Inputs must be channel-first and have a batch dim (NCHW / NCDHW).
+        - If your data is NHWC/NDHWC, please apply `EnsureChannelFirst` / `EnsureChannelFirstd` upstream.
+
+    Raises:
+        ValueError: When the input dimensions do not match the expected dimensions based on ``roi_size``.
 
     """
-    buffered = buffer_steps is not None and buffer_steps > 0
     num_spatial_dims = len(inputs.shape) - 2
+
+    # Only perform strict shape validation if roi_size is a sequence (explicit dimensions).
+    # If roi_size is an integer, it is broadcast to all spatial dimensions, so we cannot
+    # infer the expected dimensionality to enforce a strict check here.
+    if isinstance(roi_size, Sequence):
+        roi_dims = len(roi_size)
+        if num_spatial_dims != roi_dims:
+            raise ValueError(
+                f"Inputs must have {roi_dims + 2} dimensions for {roi_dims}D roi_size "
+                f"(Batch, Channel, {', '.join(['Spatial'] * roi_dims)}), "
+                f"but got inputs shape {inputs.shape}.\n"
+                "If you have channel-last data (e.g. B, D, H, W, C), please use "
+                "monai.transforms.EnsureChannelFirst or EnsureChannelFirstd upstream."
+            )
+    # -----------------------------------------------------------------
+    buffered = buffer_steps is not None and buffer_steps > 0
     if buffered:
         if buffer_dim < -num_spatial_dims or buffer_dim > num_spatial_dims:
             raise ValueError(f"buffer_dim must be in [{-num_spatial_dims}, {num_spatial_dims}], got {buffer_dim}.")
diff --git a/tests/inferers/test_sliding_window_inference.py b/tests/inferers/test_sliding_window_inference.py
index f97cbb9299..8700c4fcd0 100644
--- a/tests/inferers/test_sliding_window_inference.py
+++ b/tests/inferers/test_sliding_window_inference.py
@@ -372,6 +372,26 @@ def compute_dict(data):
         for rr, _ in zip(result_dict, expected_dict):
             np.testing.assert_allclose(result_dict[rr].cpu().numpy(), expected_dict[rr], rtol=1e-4)
 
+    def test_strict_shape_validation(self):
+        """Test strict shape validation to ensure inputs match roi_size dimensions."""
+        device = "cpu"
+        roi_size = (16, 16, 16)
+        sw_batch_size = 4
+
+        def predictor(data):
+            return data
+
+        # Case 1: input has fewer dimensions than expected (e.g. missing batch or channel dim).
+        # A 3D roi_size requires 5D input (B, C, D, H, W); here a 4D tensor is passed.
+        inputs_4d = torch.randn((1, 16, 16, 16), device=device)
+        with self.assertRaisesRegex(ValueError, "Inputs must have 5 dimensions"):
+            sliding_window_inference(inputs_4d, roi_size, sw_batch_size, predictor)
+
+        # Case 2: input is 3D (missing both the batch and channel dims).
+        inputs_3d = torch.randn((16, 16, 16), device=device)
+        with self.assertRaisesRegex(ValueError, "Inputs must have 5 dimensions"):
+            sliding_window_inference(inputs_3d, roi_size, sw_batch_size, predictor)
+
 
 class TestSlidingWindowInferenceCond(unittest.TestCase):
 
     @parameterized.expand(TEST_CASES)
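For reference, below is a minimal usage sketch (not part of the patch above) of the upstream fix that the new error message points to. It is a hedged illustration only: the identity predictor, the input shapes, and the channel_dim=-1 argument to EnsureChannelFirst (used here for a plain tensor without metadata) are assumptions for this example, not something introduced by this diff.

import torch
from monai.inferers import sliding_window_inference
from monai.transforms import EnsureChannelFirst


def predictor(x):
    # identity "network", purely for illustration
    return x


# Channel-last volume (D, H, W, C), the layout the new error message warns about.
img = torch.randn(96, 96, 96, 2)
# Move the channel axis to the front; channel_dim=-1 tells the transform where the channel
# lives when the tensor carries no metadata (assumed available in your MONAI version).
img = EnsureChannelFirst(channel_dim=-1)(img)   # -> (C, D, H, W)
img = img.unsqueeze(0)                          # add the batch dim -> (B, C, D, H, W)

# With a channel-first, batched input the strict shape check passes.
out = sliding_window_inference(img, roi_size=(64, 64, 64), sw_batch_size=2, predictor=predictor)
print(out.shape)  # expected: torch.Size([1, 2, 96, 96, 96])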