-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathtest_transcriber.py
42 lines (31 loc) · 1.81 KB
/
test_transcriber.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import unittest
import os
import numpy as np
from transcriber import WalkingBassTranscription
# Simple unit test for transcriber


class TestTranscriber(unittest.TestCase):
    """End-to-end test of WalkingBassTranscription on a bundled audio excerpt.

    Expects the working directory to contain a ``data`` folder with the
    excerpt WAV file and its beat-times CSV.
    """

    def setUp(self):
        self.transcriber = WalkingBassTranscription()
        self.fn_wav = os.path.join('data', 'ArtPepper_Anthropology_Excerpt.wav')
        self.dir_out = 'data'
        # get beat times (first column of the CSV)
        fn_csv = os.path.join('data', 'ArtPepper_Anthropology_Excerpt_beat_times.csv')
        self.beat_times = np.loadtxt(fn_csv, delimiter=',', usecols=[0])
        # Result files the transcriber is expected to produce next to the input
        # WAV (dir_out is the same 'data' folder that holds the WAV).
        stem = os.path.splitext(self.fn_wav)[0]
        self.fn_out_f0 = stem + '_bass_f0.csv'
        self.fn_out_saliency = stem + '_bass_pitch_saliency.npy'

    def _remove_outputs(self):
        """Delete previously generated result files, if any."""
        for fn in (self.fn_out_f0, self.fn_out_saliency):
            try:
                os.remove(fn)
            except FileNotFoundError:
                pass

    def test_transcriber(self):
        """Run the transcriber for every aggregation method / threshold combination."""
        for aggregation_method in ('beat', 'flex-q'):
            for threshold in (0, 0.2):
                # subTest labels failures with the parameter combination and
                # keeps running the remaining combinations.
                with self.subTest(aggregation=aggregation_method, threshold=threshold):
                    # Remove stale outputs so the existence checks below can only
                    # pass if *this* call wrote the files.
                    self._remove_outputs()
                    pitch_saliency, f_axis_midi, time_axis_sec = self.transcriber.transcribe(
                        self.fn_wav,
                        self.dir_out,
                        beat_times=self.beat_times,
                        aggregation=aggregation_method,
                        threshold=threshold)
                    # shape check: rows align with the time axis, columns with
                    # the MIDI pitch axis
                    self.assertEqual(pitch_saliency.shape[0], len(time_axis_sec))
                    self.assertEqual(pitch_saliency.shape[1], len(f_axis_midi))
                    # check that result files were generated
                    self.assertTrue(os.path.isfile(self.fn_out_f0),
                                    'missing output: {}'.format(self.fn_out_f0))
                    self.assertTrue(os.path.isfile(self.fn_out_saliency),
                                    'missing output: {}'.format(self.fn_out_saliency))
if __name__ == '__main__':
    # Allow running this module directly (e.g. `python test_transcriber.py`)
    # in addition to discovery by a test runner.
    unittest.main()