diff --git a/monai/metrics/hausdorff_distance.py b/monai/metrics/hausdorff_distance.py index 610cf0b01a..4cc4dfc592 100644 --- a/monai/metrics/hausdorff_distance.py +++ b/monai/metrics/hausdorff_distance.py @@ -23,7 +23,7 @@ ignore_background, is_binary_tensor, ) -from monai.utils import MetricReduction +from monai.utils import MetricReduction, convert_data_type from .metric import CumulativeIterationMetric @@ -153,10 +153,8 @@ def compute_hausdorff_distance( if not include_background: y_pred, y = ignore_background(y_pred=y_pred, y=y) - if isinstance(y, torch.Tensor): - y = y.float() - if isinstance(y_pred, torch.Tensor): - y_pred = y_pred.float() + y_pred = convert_data_type(y_pred, output_type=torch.Tensor, dtype=torch.float)[0] + y = convert_data_type(y, output_type=torch.Tensor, dtype=torch.float)[0] if y.shape != y_pred.shape: raise ValueError(f"y_pred and y should have same shapes, got {y_pred.shape} and {y.shape}.") @@ -176,7 +174,7 @@ def compute_hausdorff_distance( else: distance_2 = compute_percent_hausdorff_distance(edges_gt, edges_pred, distance_metric, percentile) hd[b, c] = max(distance_1, distance_2) - return torch.from_numpy(hd) + return convert_data_type(hd, output_type=torch.Tensor, device=y_pred.device, dtype=torch.float)[0] def compute_percent_hausdorff_distance( diff --git a/monai/metrics/surface_dice.py b/monai/metrics/surface_dice.py index 659cd450fd..34f2769806 100644 --- a/monai/metrics/surface_dice.py +++ b/monai/metrics/surface_dice.py @@ -238,4 +238,4 @@ def compute_surface_dice( else: nsd[b, c] = boundary_correct / boundary_complete - return convert_data_type(nsd, torch.Tensor)[0] + return convert_data_type(nsd, output_type=torch.Tensor, device=y_pred.device, dtype=torch.float)[0] diff --git a/monai/metrics/surface_distance.py b/monai/metrics/surface_distance.py index 093c92fd4b..27e7724410 100644 --- a/monai/metrics/surface_distance.py +++ b/monai/metrics/surface_distance.py @@ -145,10 +145,8 @@ def compute_average_surface_distance( if not include_background: y_pred, y = ignore_background(y_pred=y_pred, y=y) - if isinstance(y, torch.Tensor): - y = y.float() - if isinstance(y_pred, torch.Tensor): - y_pred = y_pred.float() + y_pred = convert_data_type(y_pred, output_type=torch.Tensor, dtype=torch.float)[0] + y = convert_data_type(y, output_type=torch.Tensor, dtype=torch.float)[0] if y.shape != y_pred.shape: raise ValueError(f"y_pred and y should have same shapes, got {y_pred.shape} and {y.shape}.") @@ -168,4 +166,4 @@ def compute_average_surface_distance( surface_distance = np.concatenate([surface_distance, surface_distance_2]) asd[b, c] = np.nan if surface_distance.shape == (0,) else surface_distance.mean() - return convert_data_type(asd, torch.Tensor)[0] + return convert_data_type(asd, output_type=torch.Tensor, device=y_pred.device, dtype=torch.float)[0] diff --git a/tests/test_surface_dice.py b/tests/test_surface_dice.py index 585c1754ca..043212b205 100644 --- a/tests/test_surface_dice.py +++ b/tests/test_surface_dice.py @@ -277,17 +277,17 @@ def test_not_predicted_not_present(self): # test aggregation res_bgr = sur_metric_bgr.aggregate(reduction="mean") - np.testing.assert_equal(res_bgr, torch.tensor([1 / 3], dtype=torch.float64)) + np.testing.assert_equal(res_bgr, torch.tensor([1 / 3], dtype=torch.float)) res = sur_metric.aggregate() - np.testing.assert_equal(res, torch.tensor([0], dtype=torch.float64)) + np.testing.assert_equal(res, torch.tensor([0], dtype=torch.float)) predictions_empty = torch.zeros((2, 3, 1, 1)) sur_metric_nans = SurfaceDiceMetric(class_thresholds=[1, 1, 1], include_background=True, get_not_nans=True) res_classes = sur_metric_nans(predictions_empty, predictions_empty) res, not_nans = sur_metric_nans.aggregate() np.testing.assert_array_equal(res_classes, [[np.nan, np.nan, np.nan], [np.nan, np.nan, np.nan]]) - np.testing.assert_equal(res, torch.tensor([0], dtype=torch.float64)) - np.testing.assert_equal(not_nans, torch.tensor([0], dtype=torch.float64)) + np.testing.assert_equal(res, torch.tensor([0], dtype=torch.float)) + np.testing.assert_equal(not_nans, torch.tensor([0], dtype=torch.float)) if __name__ == "__main__":