Skip to content

Commit 87c7525

Browse files
authored
Add simple benchmarking to the project (#2)
1 parent d3cac39 commit 87c7525

File tree

5 files changed

+506
-2
lines changed

5 files changed

+506
-2
lines changed

.github/workflows/linters.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -25,4 +25,4 @@ jobs:
2525
pdm install --no-lock --no-self --no-default -G linters
2626
- name: Run flake8
2727
run: |
28-
pdm run -v flake8 src/ tests/
28+
pdm run -v flake8 src/ tests/ benchmark/

benchmark/benchmark.py

+192
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,192 @@
1+
import os
2+
import numpy as np
3+
import datetime
4+
import random
5+
import string
6+
import functools
7+
import timeit
8+
9+
# Plotting
10+
import matplotlib.pyplot as plt
11+
import seaborn as sns
12+
13+
# Reference libs
14+
import fastwave
15+
import torchaudio
16+
import librosa
17+
from scipy.io import wavfile
18+
import pydub
19+
import wave
20+
21+
22+
class AudioGenerator:
    """Create a WAV file filled with random noise to feed the benchmarks.

    The file is written into the current working directory as soon as the
    object is constructed; call ``delete_generated_file`` to remove it.

    Args:
        sample_rate: Sampling rate of the generated file, in Hz.
        duration: Length of the generated audio, in seconds.
        channels: Number of channels; only 1 and 2 are accepted.
        prefix: Leading part of the generated file name.

    Raises:
        RuntimeError: If ``channels`` is neither 1 nor 2.
    """

    def __init__(
        self, sample_rate=44100, duration=5, channels=2, prefix="random_audio_"
    ):
        self.sample_rate = sample_rate
        self.duration = duration
        self.channels = channels
        self.file_name = self.generate_random_name(prefix)
        self.file_path = os.path.join(os.getcwd(), self.file_name)
        # Run generation at init!
        self.generate_scipy_audio()

    def generate_scipy_audio(self):
        """Write ``duration`` seconds of 16-bit noise to ``self.file_path``.

        Raises:
            RuntimeError: If ``self.channels`` is neither 1 nor 2.
        """
        if self.channels not in (1, 2):
            raise RuntimeError("Unsupported number of channels!")

        num_samples = int(self.sample_rate * self.duration)
        noises = [np.random.normal(0, 1, num_samples) for _ in range(self.channels)]
        audio_data = np.column_stack(noises) if self.channels == 2 else noises[0]
        # Standard-normal noise regularly exceeds [-1, 1]; clip before scaling so
        # the int16 cast saturates instead of producing out-of-range garbage.
        audio_data = (np.clip(audio_data, -1.0, 1.0) * 32767).astype(np.int16)
        wavfile.write(self.file_path, self.sample_rate, audio_data)

    def delete_generated_file(self):
        """Remove the generated file; a file that is already gone is ignored."""
        try:
            os.remove(self.file_path)
        except FileNotFoundError:
            pass

    def generate_random_name(self, prefix):
        """Return ``<prefix><YYYYmmddHHMMSS>_<3 random chars>.wav``."""
        timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
        random_suffix = "".join(
            random.choices(string.ascii_uppercase + string.digits, k=3)
        )
        return f"{prefix}{timestamp}_{random_suffix}.wav"
60+
61+
62+
def benchmark_fastwave_default(audio_generator):
    """Read the generated file with fastwave in DEFAULT mode.

    Returns the samples converted to float32 and divided by 32767.0.
    """
    wav = fastwave.read(audio_generator.file_path, mode=fastwave.ReadMode.DEFAULT)
    # Manual scaling; assumes 16-bit PCM samples, as written by the generator.
    return wav.data.astype("float32") / 32767.0
67+
68+
69+
def benchmark_fastwave_threads(audio_generator):
    """Read the generated file with fastwave in THREADS mode (6 threads).

    Returns the samples converted to float32 and divided by 32767.0.
    """
    wav = fastwave.read(
        audio_generator.file_path,
        mode=fastwave.ReadMode.THREADS,
        num_threads=6,
    )
    # Manual scaling; assumes 16-bit PCM samples, as written by the generator.
    return wav.data.astype("float32") / 32767.0
76+
77+
78+
def benchmark_fastwave_mmap_private(audio_generator):
    """Read the generated file with fastwave in MMAP_PRIVATE mode.

    Returns the samples converted to float32 and divided by 32767.0.
    """
    read_mode = fastwave.ReadMode.MMAP_PRIVATE
    wav = fastwave.read(audio_generator.file_path, mode=read_mode)
    # Manual scaling; assumes 16-bit PCM samples, as written by the generator.
    return wav.data.astype("float32") / 32767.0
85+
86+
87+
def benchmark_fastwave_mmap_shared(audio_generator):
    """Read the generated file with fastwave in MMAP_SHARED mode.

    Returns the samples converted to float32 and divided by 32767.0.
    """
    read_mode = fastwave.ReadMode.MMAP_SHARED
    wav = fastwave.read(audio_generator.file_path, mode=read_mode)
    # Manual scaling; assumes 16-bit PCM samples, as written by the generator.
    return wav.data.astype("float32") / 32767.0
92+
93+
94+
def benchmark_native_python(audio_generator):
    """Read the generated file with the standard-library ``wave`` module.

    The raw frames are interpreted as 16-bit samples and reshaped to
    ``(frames, channels)`` using the channel count stored in the file header.

    Returns:
        A float32 array of shape ``(frames, channels)`` divided by 32767.0.
    """
    # The context manager closes the handle on every call; the benchmark runs
    # this function many times, so an unclosed handle would leak each time.
    with wave.open(audio_generator.file_path, "rb") as w:
        num_channels = w.getnchannels()
        frames = w.readframes(w.getnframes())
    # Use the real channel count rather than assuming stereo, so mono files
    # keep one column per channel instead of being folded into pairs.
    audio = np.frombuffer(frames, dtype=np.int16).reshape(-1, num_channels)
    return audio.astype("float32") / 32767.0
99+
100+
101+
def benchmark_pydub(audio_generator):
    """Read the generated file with pydub and scale the samples by 32767.0.

    Returns:
        A float32 array of shape ``(channels, frames)``.
    """
    song = pydub.AudioSegment.from_file(audio_generator.file_path)
    sig = np.asarray(song.get_array_of_samples(), dtype="float32")
    # pydub documents the samples of multi-channel audio as interleaved
    # (L, R, L, R, ...), so split into (frames, channels) first and then
    # transpose; reshaping straight to (channels, -1) would mix the channels.
    sig = sig.reshape(-1, song.channels).T / 32767.0
    return sig
106+
107+
108+
def benchmark_torchaudio(audio_generator):
    """Read the generated file with ``torchaudio.load``.

    Returns the first element of the loader's result unchanged.
    """
    # No manual scaling here: the original author notes that `normalize=True`
    # already covers the `/ 32767.0` step the other benchmarks perform.
    loaded = torchaudio.load(
        audio_generator.file_path,
        normalize=True,
        channels_first=False,
    )
    return loaded[0]
115+
116+
117+
def benchmark_scipy_default(audio_generator):
    """Read the generated file with ``scipy.io.wavfile.read``.

    Returns the samples converted to float32 and divided by 32767.0.
    """
    # wavfile.read returns (rate, data); only the data is needed here.
    samples = wavfile.read(audio_generator.file_path)[1]
    return samples.astype("float32") / 32767.0
121+
122+
123+
def benchmark_scipy_mmap(audio_generator):
    """Read the generated file with ``scipy.io.wavfile.read(..., mmap=True)``.

    Returns the samples converted to float32 and divided by 32767.0.
    """
    # wavfile.read returns (rate, data); only the data is needed here.
    samples = wavfile.read(audio_generator.file_path, mmap=True)[1]
    return samples.astype("float32") / 32767.0
127+
128+
129+
def benchmark_librosa(audio_generator):
    """Read the generated file with ``librosa.load`` at its stored sample rate.

    Returns the loaded signal, transposed when it is two-dimensional.
    """
    # No manual scaling here: the original author notes that passing `dtype`
    # to librosa.load already covers the float conversion step.
    loaded = librosa.load(
        audio_generator.file_path,
        sr=None,
        dtype=np.float32,
    )
    samples = loaded[0]
    if samples.ndim == 2:
        return samples.T
    return samples
134+
135+
136+
if __name__ == "__main__":
    # Script entry point: generate one noise file, time every reader against
    # it, delete the file, then plot the best time per reader.
    # TODO: add benchmarks for `info` function
    audio_generator = AudioGenerator(sample_rate=44100, duration=60 * 30, channels=1)
    print(f"Generated file: {audio_generator.file_path}")
    print(f"Duration: {audio_generator.duration} seconds")
    print(f"Channels: {audio_generator.channels}")

    # Calls per timing run, and number of timing runs per method.
    ITERATIONS = 10
    REPS = 5

    methods = {
        "fastwave_DEFAULT": benchmark_fastwave_default,
        "fastwave_THREADS": benchmark_fastwave_threads,
        "fastwave_MMAP_PRIVATE": benchmark_fastwave_mmap_private,
        "fastwave_MMAP_SHARED": benchmark_fastwave_mmap_shared,
        "native_python": benchmark_native_python,
        "librosa": benchmark_librosa,
        "torchaudio": benchmark_torchaudio,
        "pydub": benchmark_pydub,
        "scipy_default": benchmark_scipy_default,
        "scipy_mmap": benchmark_scipy_mmap,
    }

    execution_times = []

    # try/finally so the (potentially very large) generated file is removed
    # even when one of the benchmarked libraries raises.
    try:
        for method_name, method_func in methods.items():
            execution_time = timeit.repeat(
                functools.partial(method_func, audio_generator),
                number=ITERATIONS,
                repeat=REPS,
            )
            # Each entry is the total time of ITERATIONS calls; keep the best.
            min_execution_time = min(execution_time)
            print(f"{method_name}: {min_execution_time} seconds")
            execution_times.append(min_execution_time)
    finally:
        audio_generator.delete_generated_file()

    # Plot the benchmark results
    method_names = list(methods)
    plt.figure(figsize=(10, 6))
    palette = sns.color_palette("husl", len(method_names))
    bars = plt.barh(method_names, execution_times, color=palette)
    plt.title(
        f"Benchmark Results (wav, length: {audio_generator.duration} seconds, "
        f"channel number: {audio_generator.channels})"
    )
    plt.xlabel("Execution Time (seconds, lower is better)")
    plt.ylabel("Library and method")

    # Add legend
    plt.legend(bars, method_names, loc="upper right")
    plt.tight_layout()

    # Show the plot
    plt.show()

0 commit comments

Comments
 (0)