Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Fix some broken tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ethanwharris committed Jul 19, 2021
1 parent 0052f1f commit 0c87f04
Showing 1 changed file with 9 additions and 4 deletions.
13 changes: 9 additions & 4 deletions tests/audio/speech_recognition/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,14 @@
import re
from unittest import mock

import numpy as np
import pytest
import torch

from flash import Trainer
from flash.audio import SpeechRecognition
from flash.audio.speech_recognition.data import SpeechRecognitionPostprocess, SpeechRecognitionPreprocess
from flash.core.data.data_source import DefaultDataKeys
from tests.helpers.utils import _AUDIO_TESTING, _SERVE_TESTING

# ======== Mock functions ========
Expand All @@ -30,8 +32,11 @@ class DummyDataset(torch.utils.data.Dataset):

def __getitem__(self, index):
return {
"input_values": torch.randn(size=torch.Size([86631])).float(),
"labels": torch.randn(size=(1, 77)).long(),
DefaultDataKeys.INPUT: np.random.randn(86631),
DefaultDataKeys.TARGET: "some target text",
DefaultDataKeys.METADATA: {
"sampling_rate": 16000
},
}

def __len__(self) -> int:
Expand Down Expand Up @@ -77,8 +82,8 @@ def test_jit(tmpdir):
def test_serve():
model = SpeechRecognition(backbone=TEST_BACKBONE)
# TODO: Currently only servable once a preprocess and postprocess have been attached
model._preprocess = SpeechRecognitionPreprocess(backbone=TEST_BACKBONE)
model._postprocess = SpeechRecognitionPostprocess(backbone=TEST_BACKBONE)
model._preprocess = SpeechRecognitionPreprocess()
model._postprocess = SpeechRecognitionPostprocess()
model.eval()
model.serve()

Expand Down

0 comments on commit 0c87f04

Please sign in to comment.