From e90bb841d06e0408f917fb2f42cd0827d3329e4d Mon Sep 17 00:00:00 2001 From: Yiheng Wang <68361391+yiheng-wang-nv@users.noreply.github.com> Date: Wed, 1 Feb 2023 20:11:57 +0800 Subject: [PATCH] 5919 unify output tensor device for multiple metrics (#5924) Signed-off-by: Yiheng Wang Fixes #5919 . ### Description This PR is used to unify input output tensor devices for the following metrics: 1. HausdorffDistanceMetric 2. SurfaceDiceMetric 3. SurfaceDistanceMetric ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. Signed-off-by: Yiheng Wang --- monai/metrics/hausdorff_distance.py | 10 ++++------ monai/metrics/surface_dice.py | 2 +- monai/metrics/surface_distance.py | 8 +++----- tests/test_surface_dice.py | 8 ++++---- 4 files changed, 12 insertions(+), 16 deletions(-) diff --git a/monai/metrics/hausdorff_distance.py b/monai/metrics/hausdorff_distance.py index 610cf0b01a..4cc4dfc592 100644 --- a/monai/metrics/hausdorff_distance.py +++ b/monai/metrics/hausdorff_distance.py @@ -23,7 +23,7 @@ ignore_background, is_binary_tensor, ) -from monai.utils import MetricReduction +from monai.utils import MetricReduction, convert_data_type from .metric import CumulativeIterationMetric @@ -153,10 +153,8 @@ def compute_hausdorff_distance( if not include_background: y_pred, y = ignore_background(y_pred=y_pred, y=y) - if isinstance(y, torch.Tensor): - y = y.float() - if isinstance(y_pred, torch.Tensor): - y_pred = y_pred.float() + y_pred = convert_data_type(y_pred, output_type=torch.Tensor, dtype=torch.float)[0] + y = convert_data_type(y, output_type=torch.Tensor, dtype=torch.float)[0] if y.shape != y_pred.shape: raise ValueError(f"y_pred and y should have same shapes, got {y_pred.shape} and {y.shape}.") @@ -176,7 +174,7 @@ def compute_hausdorff_distance( else: distance_2 = compute_percent_hausdorff_distance(edges_gt, edges_pred, distance_metric, percentile) hd[b, c] = max(distance_1, distance_2) - return torch.from_numpy(hd) + return convert_data_type(hd, output_type=torch.Tensor, device=y_pred.device, dtype=torch.float)[0] def compute_percent_hausdorff_distance( diff --git a/monai/metrics/surface_dice.py b/monai/metrics/surface_dice.py index 659cd450fd..34f2769806 100644 --- a/monai/metrics/surface_dice.py +++ b/monai/metrics/surface_dice.py @@ -238,4 +238,4 @@ def compute_surface_dice( else: nsd[b, c] = boundary_correct / boundary_complete - return convert_data_type(nsd, torch.Tensor)[0] + return convert_data_type(nsd, output_type=torch.Tensor, device=y_pred.device, dtype=torch.float)[0] diff --git a/monai/metrics/surface_distance.py b/monai/metrics/surface_distance.py index 093c92fd4b..27e7724410 100644 --- a/monai/metrics/surface_distance.py +++ b/monai/metrics/surface_distance.py @@ -145,10 +145,8 @@ def compute_average_surface_distance( if not include_background: y_pred, y = ignore_background(y_pred=y_pred, y=y) - if isinstance(y, torch.Tensor): - y = y.float() - if isinstance(y_pred, torch.Tensor): - y_pred = y_pred.float() + y_pred = convert_data_type(y_pred, output_type=torch.Tensor, dtype=torch.float)[0] + y = convert_data_type(y, output_type=torch.Tensor, dtype=torch.float)[0] if y.shape != y_pred.shape: raise ValueError(f"y_pred and y should have same shapes, got {y_pred.shape} and {y.shape}.") @@ -168,4 +166,4 @@ def compute_average_surface_distance( surface_distance = np.concatenate([surface_distance, surface_distance_2]) asd[b, c] = np.nan if surface_distance.shape == (0,) else surface_distance.mean() - return convert_data_type(asd, torch.Tensor)[0] + return convert_data_type(asd, output_type=torch.Tensor, device=y_pred.device, dtype=torch.float)[0] diff --git a/tests/test_surface_dice.py b/tests/test_surface_dice.py index 585c1754ca..043212b205 100644 --- a/tests/test_surface_dice.py +++ b/tests/test_surface_dice.py @@ -277,17 +277,17 @@ def test_not_predicted_not_present(self): # test aggregation res_bgr = sur_metric_bgr.aggregate(reduction="mean") - np.testing.assert_equal(res_bgr, torch.tensor([1 / 3], dtype=torch.float64)) + np.testing.assert_equal(res_bgr, torch.tensor([1 / 3], dtype=torch.float)) res = sur_metric.aggregate() - np.testing.assert_equal(res, torch.tensor([0], dtype=torch.float64)) + np.testing.assert_equal(res, torch.tensor([0], dtype=torch.float)) predictions_empty = torch.zeros((2, 3, 1, 1)) sur_metric_nans = SurfaceDiceMetric(class_thresholds=[1, 1, 1], include_background=True, get_not_nans=True) res_classes = sur_metric_nans(predictions_empty, predictions_empty) res, not_nans = sur_metric_nans.aggregate() np.testing.assert_array_equal(res_classes, [[np.nan, np.nan, np.nan], [np.nan, np.nan, np.nan]]) - np.testing.assert_equal(res, torch.tensor([0], dtype=torch.float64)) - np.testing.assert_equal(not_nans, torch.tensor([0], dtype=torch.float64)) + np.testing.assert_equal(res, torch.tensor([0], dtype=torch.float)) + np.testing.assert_equal(not_nans, torch.tensor([0], dtype=torch.float)) if __name__ == "__main__":