diff --git a/src/torchmetrics/__init__.py b/src/torchmetrics/__init__.py index b1549dfaf8b..2fa370cb1c9 100644 --- a/src/torchmetrics/__init__.py +++ b/src/torchmetrics/__init__.py @@ -20,6 +20,13 @@ if not hasattr(PIL, "PILLOW_VERSION"): PIL.PILLOW_VERSION = PIL.__version__ +if package_available("scipy"): + import scipy.signal + + # back compatibility patch due to SMRMpy using scipy.signal.hamming + if not hasattr(scipy.signal, "hamming"): + scipy.signal.hamming = scipy.signal.windows.hamming + from torchmetrics import functional # noqa: E402 from torchmetrics.aggregation import ( # noqa: E402 CatMetric,