diff --git a/flash/audio/classification/transforms.py b/flash/audio/classification/transforms.py index 6189b07f16..02a9ed2cbc 100644 --- a/flash/audio/classification/transforms.py +++ b/flash/audio/classification/transforms.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import os from typing import Callable, Dict, Tuple import torch @@ -19,10 +18,7 @@ from flash.core.data.data_source import DefaultDataKeys from flash.core.data.transforms import ApplyToKeys, kornia_collate, merge_transforms -from flash.core.utilities.imports import _KORNIA_AVAILABLE, _TORCHAUDIO_AVAILABLE, _TORCHVISION_AVAILABLE - -if _KORNIA_AVAILABLE: - import kornia as K +from flash.core.utilities.imports import _TORCHAUDIO_AVAILABLE, _TORCHVISION_AVAILABLE if _TORCHVISION_AVAILABLE: import torchvision @@ -35,19 +31,6 @@ def default_transforms(spectrogram_size: Tuple[int, int]) -> Dict[str, Callable]: """The default transforms for audio classification for spectrograms: resize the spectrogram, convert the spectrogram and target to a tensor, and collate the batch.""" - if _KORNIA_AVAILABLE and os.getenv("FLASH_TESTING", "0") != "1": - # Better approach as all transforms are applied on tensor directly - return { - "to_tensor_transform": nn.Sequential( - ApplyToKeys(DefaultDataKeys.INPUT, torchvision.transforms.ToTensor()), - ApplyToKeys(DefaultDataKeys.TARGET, torch.as_tensor), - ), - "post_tensor_transform": ApplyToKeys( - DefaultDataKeys.INPUT, - K.geometry.Resize(spectrogram_size), - ), - "collate": kornia_collate, - } return { "pre_tensor_transform": ApplyToKeys(DefaultDataKeys.INPUT, T.Resize(spectrogram_size)), "to_tensor_transform": nn.Sequential( @@ -60,13 +43,12 @@ def default_transforms(spectrogram_size: Tuple[int, int]) -> Dict[str, Callable] def train_default_transforms(spectrogram_size: Tuple[int, int], time_mask_param: int, 
freq_mask_param: int) -> Dict[str, Callable]: - """During training we apply the default transforms with aditional ``TimeMasking`` and ``Frequency Masking``""" - if os.getenv("FLASH_TESTING", "0") != 1: - transforms = { - "post_tensor_transform": nn.Sequential( - ApplyToKeys(DefaultDataKeys.INPUT, TAudio.TimeMasking(time_mask_param=time_mask_param)), - ApplyToKeys(DefaultDataKeys.INPUT, TAudio.FrequencyMasking(freq_mask_param=freq_mask_param)) - ) - } + """During training we apply the default transforms with additional ``TimeMasking`` and ``FrequencyMasking``.""" + transforms = { + "post_tensor_transform": nn.Sequential( + ApplyToKeys(DefaultDataKeys.INPUT, TAudio.TimeMasking(time_mask_param=time_mask_param)), + ApplyToKeys(DefaultDataKeys.INPUT, TAudio.FrequencyMasking(freq_mask_param=freq_mask_param)) + ) + } return merge_transforms(default_transforms(spectrogram_size), transforms)