diff --git a/CHANGELOG.md b/CHANGELOG.md index db59eb904d8..3569aead873 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -53,6 +53,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed +- Fixed corner case in `Iou` metric for single empty prediction tensors ([#2780](https://github.com/Lightning-AI/torchmetrics/pull/2780)) + + - Fixed `PSNR` calculation for integer type input images ([#2788](https://github.com/Lightning-AI/torchmetrics/pull/2788)) diff --git a/examples/audio/pesq.py b/examples/audio/pesq.py new file mode 100644 index 00000000000..6afde2bfdd5 --- /dev/null +++ b/examples/audio/pesq.py @@ -0,0 +1,117 @@ +""" +Evaluating Speech Quality with PESQ metric +============================================== + +This notebook will guide you through calculating the Perceptual Evaluation of Speech Quality (PESQ) score, + a key metric in assessing how effective noise reduction and enhancement techniques are in improving speech quality. + PESQ is widely adopted in industries such as telecommunications, VoIP, and audio processing. + It provides an objective way to measure the perceived quality of speech signals from a human listener's perspective. + +Imagine being on a noisy street, trying to have a phone call. The technology behind the scenes aims + to clean up your voice and make it sound clearer on the other end. But how do engineers measure that improvement? + This is where PESQ comes in. In this notebook, we will simulate a similar scenario, applying a simple noise reduction + technique and using the PESQ score to evaluate how much the speech quality improves. 
+""" + +# %% +# Import necessary libraries +import matplotlib.pyplot as plt +import numpy as np +import torch +import torchaudio +from torchmetrics.audio import PerceptualEvaluationSpeechQuality + +# %% +# Generate Synthetic Clean and Noisy Audio Signals +# We'll generate a clean sine wave (representing a clean speech signal) and add white noise to simulate the noisy version. + + +def generate_sine_wave(frequency, duration, sample_rate, amplitude: float = 0.5): + """Generate a clean sine wave at a given frequency.""" + t = torch.linspace(0, duration, int(sample_rate * duration)) + return amplitude * torch.sin(2 * np.pi * frequency * t) + + +def add_noise(waveform: torch.Tensor, noise_factor: float = 0.05) -> torch.Tensor: + """Add white noise to a waveform.""" + noise = noise_factor * torch.randn(waveform.size()) + return waveform + noise + + +# Parameters for the synthetic audio +sample_rate = 16000 # 16 kHz typical for speech +duration = 3 # 3 seconds of audio +frequency = 440 # A4 note, can represent a simple speech-like tone + +# Generate the clean sine wave +clean_waveform = generate_sine_wave(frequency, duration, sample_rate) + +# Generate the noisy waveform by adding white noise +noisy_waveform = add_noise(clean_waveform) + + +# %% +# Apply Basic Noise Reduction Technique +# In this step, we apply a simple spectral gating method for noise reduction using torchaudio's +# `spectrogram` method. This is to simulate the enhancement of noisy speech. 
+ + +def reduce_noise(noisy_signal: torch.Tensor, threshold: float = 0.2) -> torch.Tensor: + """Basic noise reduction using spectral gating.""" + # Compute the spectrogram + spec = torchaudio.transforms.Spectrogram()(noisy_signal) + + # Apply threshold-based gating: values below the threshold will be zeroed out + spec_denoised = spec * (spec > threshold) + + # Convert back to the waveform + return torchaudio.transforms.GriffinLim()(spec_denoised) + + +# Apply noise reduction to the noisy waveform +enhanced_waveform = reduce_noise(noisy_waveform) + +# %% +# Initialize the PESQ Metric +# PESQ can be computed in two modes: 'wb' (wideband) or 'nb' (narrowband). +# Here, we are using 'wb' mode for wideband speech quality evaluation. +pesq_metric = PerceptualEvaluationSpeechQuality(fs=sample_rate, mode="wb") + +# %% +# Compute PESQ Scores +# We will calculate the PESQ scores for both the noisy and enhanced versions compared to the clean signal. +# The PESQ scores give us a numerical evaluation of how well the enhanced speech +# compares to the clean speech. Higher scores indicate better quality. + +pesq_noisy = pesq_metric(clean_waveform, noisy_waveform) +pesq_enhanced = pesq_metric(clean_waveform, enhanced_waveform) + +print(f"PESQ Score for Noisy Audio: {pesq_noisy.item():.4f}") +print(f"PESQ Score for Enhanced Audio: {pesq_enhanced.item():.4f}") + +# %% +# Visualize the waveforms +# We can visualize the waveforms of the clean, noisy, and enhanced audio to see the differences. 
+fig, axs = plt.subplots(3, 1, figsize=(12, 9)) + +# Plot clean waveform +axs[0].plot(clean_waveform.numpy()) +axs[0].set_title("Clean Audio Waveform (Sine Wave)") +axs[0].set_xlabel("Time") +axs[0].set_ylabel("Amplitude") + +# Plot noisy waveform +axs[1].plot(noisy_waveform.numpy(), color="orange") +axs[1].set_title(f"Noisy Audio Waveform (PESQ: {pesq_noisy.item():.4f})") +axs[1].set_xlabel("Time") +axs[1].set_ylabel("Amplitude") + +# Plot enhanced waveform +axs[2].plot(enhanced_waveform.numpy(), color="green") +axs[2].set_title(f"Enhanced Audio Waveform (PESQ: {pesq_enhanced.item():.4f})") +axs[2].set_xlabel("Time") +axs[2].set_ylabel("Amplitude") + +# Adjust layout for better visualization +fig.tight_layout() +plt.show() diff --git a/examples/audio/signal_to_noise_ratio.py b/examples/audio/signal_to_noise_ratio.py index 01285f58b70..c7130a895e4 100644 --- a/examples/audio/signal_to_noise_ratio.py +++ b/examples/audio/signal_to_noise_ratio.py @@ -16,13 +16,10 @@ import torch from torchmetrics.audio import SignalNoiseRatio -# Set seed for reproducibility -torch.manual_seed(42) -np.random.seed(42) - - # %% # Generate a clean signal (simulating a high-quality recording) + + def generate_clean_signal(length: int = 1000) -> Tuple[np.ndarray, np.ndarray]: """Generate a clean signal (sine wave)""" t = np.linspace(0, 1, length) @@ -32,6 +29,8 @@ def generate_clean_signal(length: int = 1000) -> Tuple[np.ndarray, np.ndarray]: # %% # Add Gaussian noise to the signal to simulate the noisy environment + + def add_noise(signal: np.ndarray, noise_level: float = 0.5) -> np.ndarray: """Add Gaussian noise to the signal.""" noise = noise_level * np.random.randn(signal.shape[0]) @@ -40,6 +39,8 @@ def add_noise(signal: np.ndarray, noise_level: float = 0.5) -> np.ndarray: # %% # Apply FFT to filter out the noise + + def fft_denoise(noisy_signal: np.ndarray, threshold: float) -> np.ndarray: """Denoise the signal using FFT.""" freq_domain = np.fft.fft(noisy_signal) # Filter frequencies 
using FFT @@ -50,6 +51,7 @@ def fft_denoise(noisy_signal: np.ndarray, threshold: float) -> np.ndarray: # %% # Generate and plot clean, noisy, and denoised signals to visualize the reconstruction + length = 1000 t, clean_signal = generate_clean_signal(length) noisy_signal = add_noise(clean_signal, noise_level=0.5) diff --git a/examples/image/clip_score.py b/examples/image/clip_score.py index e465ed8ce1f..f73c5d68333 100644 --- a/examples/image/clip_score.py +++ b/examples/image/clip_score.py @@ -19,6 +19,7 @@ # %% # Get sample images + images = { "astronaut": astronaut(), "cat": cat(), @@ -27,6 +28,7 @@ # %% # Define a hypothetical captions for the images + captions = [ "A photo of an astronaut.", "A photo of a cat.", @@ -35,6 +37,7 @@ # %% # Define the models for CLIPScore + models = [ "openai/clip-vit-base-patch16", # "openai/clip-vit-base-patch32", @@ -44,6 +47,7 @@ # %% # Collect scores for each image-caption pair + score_results = [] for model in models: clip_score = CLIPScore(model_name_or_path=model) @@ -54,6 +58,7 @@ # %% # Create an animation to display the scores + fig, (ax_img, ax_table) = plt.subplots(1, 2, figsize=(10, 5)) diff --git a/pyproject.toml b/pyproject.toml index 8b183d4c6b0..5a765978081 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,6 +55,7 @@ lint.per-file-ignores."docs/source/conf.py" = [ "D103", ] lint.per-file-ignores."examples/*" = [ + "ANN", # any annotations "D205", # 1 blank line required between summary line and description "D212", # [*] Multi-line docstring summary should start at the first line "D415", # First line should end with a period, question mark, or exclamation point diff --git a/src/torchmetrics/clustering/calinski_harabasz_score.py b/src/torchmetrics/clustering/calinski_harabasz_score.py index a9cb9df1c01..483e4332148 100644 --- a/src/torchmetrics/clustering/calinski_harabasz_score.py +++ b/src/torchmetrics/clustering/calinski_harabasz_score.py @@ -56,11 +56,11 @@ class CalinskiHarabaszScore(Metric): Example:: 
>>> from torch import randn, randint >>> from torchmetrics.clustering import CalinskiHarabaszScore - >>> data = randn(10, 3) - >>> labels = randint(3, (10,)) + >>> data = randn(20, 3) + >>> labels = randint(3, (20,)) >>> metric = CalinskiHarabaszScore() >>> metric(data, labels) - tensor(3.0053) + tensor(2.2128) """ @@ -108,7 +108,7 @@ def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_ >>> import torch >>> from torchmetrics.clustering import CalinskiHarabaszScore >>> metric = CalinskiHarabaszScore() - >>> metric.update(torch.randn(10, 3), torch.randint(0, 2, (10,))) + >>> metric.update(torch.randn(20, 3), torch.randint(3, (20,))) >>> fig_, ax_ = metric.plot(metric.compute()) .. plot:: @@ -120,7 +120,7 @@ def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_ >>> metric = CalinskiHarabaszScore() >>> values = [ ] >>> for _ in range(10): - ... values.append(metric(torch.randn(10, 3), torch.randint(0, 2, (10,)))) + ... values.append(metric(torch.randn(20, 3), torch.randint(3, (20,)))) >>> fig_, ax_ = metric.plot(values) """ diff --git a/src/torchmetrics/detection/iou.py b/src/torchmetrics/detection/iou.py index 7b4a60200ca..e809b27ce6a 100644 --- a/src/torchmetrics/detection/iou.py +++ b/src/torchmetrics/detection/iou.py @@ -211,7 +211,8 @@ def compute(self) -> dict: """Computes IoU based on inputs passed in to ``update`` previously.""" score = torch.cat([mat[mat != self._invalid_val] for mat in self.iou_matrix], 0).mean() results: Dict[str, Tensor] = {f"{self._iou_type}": score} - + if torch.isnan(score): # if no valid boxes are found + results[f"{self._iou_type}"] = torch.tensor(0.0, device=score.device) if self.class_metrics: gt_labels = dim_zero_cat(self.groundtruth_labels) classes = gt_labels.unique().tolist() if len(gt_labels) > 0 else [] diff --git a/src/torchmetrics/functional/clustering/calinski_harabasz_score.py b/src/torchmetrics/functional/clustering/calinski_harabasz_score.py index 
e28e7f9a9c3..7501ff8f15d 100644 --- a/src/torchmetrics/functional/clustering/calinski_harabasz_score.py +++ b/src/torchmetrics/functional/clustering/calinski_harabasz_score.py @@ -33,10 +33,10 @@ def calinski_harabasz_score(data: Tensor, labels: Tensor) -> Tensor: Example: >>> from torch import randn, randint >>> from torchmetrics.functional.clustering import calinski_harabasz_score - >>> data = randn(10, 3) - >>> labels = randint(0, 2, (10,)) + >>> data = randn(20, 3) + >>> labels = randint(0, 3, (20,)) >>> calinski_harabasz_score(data, labels) - tensor(3.4998) + tensor(2.2128) """ _validate_intrinsic_cluster_data(data, labels) diff --git a/src/torchmetrics/functional/clustering/davies_bouldin_score.py b/src/torchmetrics/functional/clustering/davies_bouldin_score.py index 89ee1bb5d19..1d6a7222703 100644 --- a/src/torchmetrics/functional/clustering/davies_bouldin_score.py +++ b/src/torchmetrics/functional/clustering/davies_bouldin_score.py @@ -33,10 +33,10 @@ def davies_bouldin_score(data: Tensor, labels: Tensor) -> Tensor: Example: >>> from torch import randn, randint >>> from torchmetrics.functional.clustering import davies_bouldin_score - >>> data = randn(10, 3) - >>> labels = randint(0, 2, (10,)) + >>> data = randn(20, 3) + >>> labels = randint(0, 3, (20,)) >>> davies_bouldin_score(data, labels) - tensor(1.3249) + tensor(2.7418) """ _validate_intrinsic_cluster_data(data, labels) diff --git a/tests/unittests/detection/test_intersection.py b/tests/unittests/detection/test_intersection.py index c42a6763ba9..a028a014ea1 100644 --- a/tests/unittests/detection/test_intersection.py +++ b/tests/unittests/detection/test_intersection.py @@ -63,6 +63,8 @@ def _tv_wrapper_class(preds, target, base_fn, respect_labels, iou_threshold, cla base_name = {tv_ciou: "ciou", tv_diou: "diou", tv_giou: "giou", tv_iou: "iou"}[base_fn] result = {f"{base_name}": score.cpu()} + if torch.isnan(score): + result.update({f"{base_name}": torch.tensor(0.0)}) if class_metrics: for cl in 
torch.cat(classes).unique().tolist(): class_score, numel = 0, 0 @@ -71,7 +73,6 @@ def _tv_wrapper_class(preds, target, base_fn, respect_labels, iou_threshold, cla class_score += masked_s[masked_s != -1].sum() numel += masked_s[masked_s != -1].numel() result.update({f"{base_name}/cl_{cl}": class_score.cpu() / numel}) - return result @@ -328,6 +329,32 @@ def test_functional_error_on_wrong_input_shape(self, class_metric, functional_me with pytest.raises(ValueError, match="Expected target to be of shape.*"): functional_metric(torch.randn(25, 4), torch.randn(25, 25)) + def test_corner_case_only_one_empty_prediction(self, class_metric, functional_metric, reference_metric): + """Test that the metric does not crash when there is only one empty prediction.""" + target = [ + { + "boxes": torch.tensor([ + [8.0000, 70.0000, 76.0000, 110.0000], + [247.0000, 131.0000, 315.0000, 175.0000], + [361.0000, 177.0000, 395.0000, 203.0000], + ]), + "labels": torch.tensor([0, 0, 0]), + } + ] + preds = [ + { + "boxes": torch.empty(size=(0, 4)), + "labels": torch.tensor([], dtype=torch.int64), + "scores": torch.tensor([]), + } + ] + + metric = class_metric() + metric.update(preds, target) + res = metric.compute() + for val in res.values(): + assert val == torch.tensor(0.0) + def test_corner_case(): """See issue: https://github.com/Lightning-AI/torchmetrics/issues/1921."""