Skip to content

Commit

Permalink
5919 unify output tensor device for multiple metrics (#5924)
Browse files Browse the repository at this point in the history
Signed-off-by: Yiheng Wang <[email protected]>

Fixes #5919 .

### Description

This PR is used to unify input output tensor devices for the following
metrics:
1. HausdorffDistanceMetric
2. SurfaceDiceMetric
3. SurfaceDistanceMetric

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

Signed-off-by: Yiheng Wang <[email protected]>
  • Loading branch information
yiheng-wang-nv authored Feb 1, 2023
1 parent b592164 commit e90bb84
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 16 deletions.
10 changes: 4 additions & 6 deletions monai/metrics/hausdorff_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
ignore_background,
is_binary_tensor,
)
from monai.utils import MetricReduction
from monai.utils import MetricReduction, convert_data_type

from .metric import CumulativeIterationMetric

Expand Down Expand Up @@ -153,10 +153,8 @@ def compute_hausdorff_distance(

if not include_background:
y_pred, y = ignore_background(y_pred=y_pred, y=y)
if isinstance(y, torch.Tensor):
y = y.float()
if isinstance(y_pred, torch.Tensor):
y_pred = y_pred.float()
y_pred = convert_data_type(y_pred, output_type=torch.Tensor, dtype=torch.float)[0]
y = convert_data_type(y, output_type=torch.Tensor, dtype=torch.float)[0]

if y.shape != y_pred.shape:
raise ValueError(f"y_pred and y should have same shapes, got {y_pred.shape} and {y.shape}.")
Expand All @@ -176,7 +174,7 @@ def compute_hausdorff_distance(
else:
distance_2 = compute_percent_hausdorff_distance(edges_gt, edges_pred, distance_metric, percentile)
hd[b, c] = max(distance_1, distance_2)
return torch.from_numpy(hd)
return convert_data_type(hd, output_type=torch.Tensor, device=y_pred.device, dtype=torch.float)[0]


def compute_percent_hausdorff_distance(
Expand Down
2 changes: 1 addition & 1 deletion monai/metrics/surface_dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,4 +238,4 @@ def compute_surface_dice(
else:
nsd[b, c] = boundary_correct / boundary_complete

return convert_data_type(nsd, torch.Tensor)[0]
return convert_data_type(nsd, output_type=torch.Tensor, device=y_pred.device, dtype=torch.float)[0]
8 changes: 3 additions & 5 deletions monai/metrics/surface_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,10 +145,8 @@ def compute_average_surface_distance(
if not include_background:
y_pred, y = ignore_background(y_pred=y_pred, y=y)

if isinstance(y, torch.Tensor):
y = y.float()
if isinstance(y_pred, torch.Tensor):
y_pred = y_pred.float()
y_pred = convert_data_type(y_pred, output_type=torch.Tensor, dtype=torch.float)[0]
y = convert_data_type(y, output_type=torch.Tensor, dtype=torch.float)[0]

if y.shape != y_pred.shape:
raise ValueError(f"y_pred and y should have same shapes, got {y_pred.shape} and {y.shape}.")
Expand All @@ -168,4 +166,4 @@ def compute_average_surface_distance(
surface_distance = np.concatenate([surface_distance, surface_distance_2])
asd[b, c] = np.nan if surface_distance.shape == (0,) else surface_distance.mean()

return convert_data_type(asd, torch.Tensor)[0]
return convert_data_type(asd, output_type=torch.Tensor, device=y_pred.device, dtype=torch.float)[0]
8 changes: 4 additions & 4 deletions tests/test_surface_dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,17 +277,17 @@ def test_not_predicted_not_present(self):

# test aggregation
res_bgr = sur_metric_bgr.aggregate(reduction="mean")
np.testing.assert_equal(res_bgr, torch.tensor([1 / 3], dtype=torch.float64))
np.testing.assert_equal(res_bgr, torch.tensor([1 / 3], dtype=torch.float))
res = sur_metric.aggregate()
np.testing.assert_equal(res, torch.tensor([0], dtype=torch.float64))
np.testing.assert_equal(res, torch.tensor([0], dtype=torch.float))

predictions_empty = torch.zeros((2, 3, 1, 1))
sur_metric_nans = SurfaceDiceMetric(class_thresholds=[1, 1, 1], include_background=True, get_not_nans=True)
res_classes = sur_metric_nans(predictions_empty, predictions_empty)
res, not_nans = sur_metric_nans.aggregate()
np.testing.assert_array_equal(res_classes, [[np.nan, np.nan, np.nan], [np.nan, np.nan, np.nan]])
np.testing.assert_equal(res, torch.tensor([0], dtype=torch.float64))
np.testing.assert_equal(not_nans, torch.tensor([0], dtype=torch.float64))
np.testing.assert_equal(res, torch.tensor([0], dtype=torch.float))
np.testing.assert_equal(not_nans, torch.tensor([0], dtype=torch.float))


if __name__ == "__main__":
Expand Down

0 comments on commit e90bb84

Please sign in to comment.