Skip to content

Commit 2709853

Browse files
committed
Make a copy of preds, target before modifying them in-place
1 parent 2c36e0b commit 2709853

File tree

2 files changed

+8
-10
lines changed

2 files changed

+8
-10
lines changed

src/torchmetrics/functional/detection/panoptic_quality.py

+8-7
Original file line numberDiff line numberDiff line change
@@ -186,17 +186,18 @@ def _prepocess_inputs(
186186
The preprocessed input tensor flattened along the spatial dimensions.
187187
"""
188188
# flatten the spatial dimensions of the input tensor, e.g., (B, H, W, C) -> (B, H*W, C).
189-
inputs = torch.flatten(inputs, 1, -2)
190-
mask_stuffs = _isin(inputs[:, :, 0], list(stuffs))
191-
mask_things = _isin(inputs[:, :, 0], list(things))
189+
out = inputs.detach().clone()
190+
out = torch.flatten(out, 1, -2)
191+
mask_stuffs = _isin(out[:, :, 0], list(stuffs))
192+
mask_things = _isin(out[:, :, 0], list(things))
192193
# reset instance IDs of stuffs
193194
mask_stuffs_instance = torch.stack([torch.zeros_like(mask_stuffs), mask_stuffs], dim=-1)
194-
inputs[mask_stuffs_instance] = 0
195+
out[mask_stuffs_instance] = 0
195196
if not allow_unknown_category and not torch.all(mask_things | mask_stuffs):
196-
raise ValueError("Unknown categories found.")
197+
raise ValueError(f"Unknown categories found: {out[~(mask_things|mask_stuffs)]}")
197198
# set unknown categories to void color
198-
inputs[~(mask_things | mask_stuffs)] = inputs.new(void_color)
199-
return inputs
199+
out[~(mask_things | mask_stuffs)] = out.new(void_color)
200+
return out
200201

201202

202203
def _panoptic_quality_update_sample(

tests/unittests/detection/test_panoptic_quality.py

-3
Original file line numberDiff line numberDiff line change
@@ -129,9 +129,6 @@ def test_empty_metric():
129129

130130
def test_error_on_wrong_input():
131131
"""Test class input validation."""
132-
# with pytest.raises(TypeError, match="Expected argument `things` to be of type.*"):
133-
# PanopticQuality(things=[0], stuffs={1})
134-
135132
with pytest.raises(TypeError, match="Expected argument `stuffs` to contain `int` categories.*"):
136133
PanopticQuality(things={0}, stuffs={"sky"})
137134

0 commit comments

Comments
 (0)