Skip to content

Commit

Permalink
fix: compatibility audio do with new scipy (#2733)
Browse files Browse the repository at this point in the history
* compatibility audio do with new `scipy`
* smaller array to fix torch.unique case

---------

Co-authored-by: Nicki Skafte Detlefsen <[email protected]>
  • Loading branch information
Borda and SkafteNicki committed Sep 11, 2024
1 parent 80929b5 commit f12e7af
Show file tree
Hide file tree
Showing 6 changed files with 15 additions and 2 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Correct the padding related calculation errors in SSIM ([#2721](https://github.com/Lightning-AI/torchmetrics/pull/2721))


- Fixed compatibility of audio domain with new `scipy` ([#2733](https://github.com/Lightning-AI/torchmetrics/pull/2733))


- Fixed how `prefix`/`postfix` works in `MultitaskWrapper` ([#2722](https://github.com/Lightning-AI/torchmetrics/pull/2722))


Expand Down
7 changes: 7 additions & 0 deletions src/torchmetrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,13 @@
if not hasattr(PIL, "PILLOW_VERSION"):
PIL.PILLOW_VERSION = PIL.__version__

if package_available("scipy"):
import scipy.signal

# back compatibility patch due to SMRMpy using scipy.signal.hamming
if not hasattr(scipy.signal, "hamming"):
scipy.signal.hamming = scipy.signal.windows.hamming

from torchmetrics import functional # noqa: E402
from torchmetrics.aggregation import ( # noqa: E402
CatMetric,
Expand Down
1 change: 1 addition & 0 deletions src/torchmetrics/functional/nominal/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from torchmetrics.functional.nominal.cramers import cramers_v, cramers_v_matrix
from torchmetrics.functional.nominal.fleiss_kappa import fleiss_kappa
from torchmetrics.functional.nominal.pearson import (
Expand Down
1 change: 1 addition & 0 deletions src/torchmetrics/nominal/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from torchmetrics.nominal.cramers import CramersV
from torchmetrics.nominal.fleiss_kappa import FleissKappa
from torchmetrics.nominal.pearson import PearsonsContingencyCoefficient
Expand Down
1 change: 1 addition & 0 deletions src/torchmetrics/utilities/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
_MECAB_KO_DIC_AVAILABLE = RequirementCache("mecab_ko_dic")
_IPADIC_AVAILABLE = RequirementCache("ipadic")
_SENTENCEPIECE_AVAILABLE = RequirementCache("sentencepiece")
_SCIPI_AVAILABLE = RequirementCache("scipy")
_SKLEARN_GREATER_EQUAL_1_3 = RequirementCache("scikit-learn>=1.3.0")

_LATEX_AVAILABLE: bool = shutil.which("latex") is not None
4 changes: 2 additions & 2 deletions tests/unittests/classification/test_stat_scores.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,8 +582,8 @@ def test_support_for_int():
"""See issue: https://github.com/Lightning-AI/torchmetrics/issues/1970."""
seed_all(42)
metric = MulticlassStatScores(num_classes=4, average="none", multidim_average="samplewise", ignore_index=0)
prediction = torch.randint(low=0, high=4, size=(1, 224, 224)).to(torch.uint8)
label = torch.randint(low=0, high=4, size=(1, 224, 224)).to(torch.uint8)
prediction = torch.randint(low=0, high=4, size=(1, 50, 50)).to(torch.uint8)
label = torch.randint(low=0, high=4, size=(1, 50, 50)).to(torch.uint8)
score = metric(preds=prediction, target=label)
assert score.shape == (1, 4, 5)

Expand Down

0 comments on commit f12e7af

Please sign in to comment.