Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Updates
Browse files Browse the repository at this point in the history
  • Loading branch information
ethanwharris committed Jul 16, 2021
1 parent 97e200b commit a3297c9
Showing 1 changed file with 8 additions and 26 deletions.
34 changes: 8 additions & 26 deletions flash/audio/classification/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Callable, Dict, Tuple

import torch
from torch import nn

from flash.core.data.data_source import DefaultDataKeys
from flash.core.data.transforms import ApplyToKeys, kornia_collate, merge_transforms
from flash.core.utilities.imports import _KORNIA_AVAILABLE, _TORCHAUDIO_AVAILABLE, _TORCHVISION_AVAILABLE

if _KORNIA_AVAILABLE:
import kornia as K
from flash.core.utilities.imports import _TORCHAUDIO_AVAILABLE, _TORCHVISION_AVAILABLE

if _TORCHVISION_AVAILABLE:
import torchvision
Expand All @@ -35,19 +31,6 @@
def default_transforms(spectrogram_size: Tuple[int, int]) -> Dict[str, Callable]:
"""The default transforms for audio classification for spectrograms: resize the spectrogram,
convert the spectrogram and target to a tensor, and collate the batch."""
if _KORNIA_AVAILABLE and os.getenv("FLASH_TESTING", "0") != "1":
# Better approach as all transforms are applied on tensor directly
return {
"to_tensor_transform": nn.Sequential(
ApplyToKeys(DefaultDataKeys.INPUT, torchvision.transforms.ToTensor()),
ApplyToKeys(DefaultDataKeys.TARGET, torch.as_tensor),
),
"post_tensor_transform": ApplyToKeys(
DefaultDataKeys.INPUT,
K.geometry.Resize(spectrogram_size),
),
"collate": kornia_collate,
}
return {
"pre_tensor_transform": ApplyToKeys(DefaultDataKeys.INPUT, T.Resize(spectrogram_size)),
"to_tensor_transform": nn.Sequential(
Expand All @@ -60,13 +43,12 @@ def default_transforms(spectrogram_size: Tuple[int, int]) -> Dict[str, Callable]

def train_default_transforms(spectrogram_size: Tuple[int, int], time_mask_param: int,
freq_mask_param: int) -> Dict[str, Callable]:
"""During training we apply the default transforms with aditional ``TimeMasking`` and ``Frequency Masking``"""
if os.getenv("FLASH_TESTING", "0") != 1:
transforms = {
"post_tensor_transform": nn.Sequential(
ApplyToKeys(DefaultDataKeys.INPUT, TAudio.TimeMasking(time_mask_param=time_mask_param)),
ApplyToKeys(DefaultDataKeys.INPUT, TAudio.FrequencyMasking(freq_mask_param=freq_mask_param))
)
}
"""During training we apply the default transforms with additional ``TimeMasking`` and ``Frequency Masking``"""
transforms = {
"post_tensor_transform": nn.Sequential(
ApplyToKeys(DefaultDataKeys.INPUT, TAudio.TimeMasking(time_mask_param=time_mask_param)),
ApplyToKeys(DefaultDataKeys.INPUT, TAudio.FrequencyMasking(freq_mask_param=freq_mask_param))
)
}

return merge_transforms(default_transforms(spectrogram_size), transforms)

0 comments on commit a3297c9

Please sign in to comment.