removing the need for jsons dependency #18

Merged
25 commits merged on Jul 1, 2024
Changes from all commits (25 commits)
17e30a4
.
MahmoudAshraf97 Jun 20, 2024
46532fc
Merge branch 'mobiusml:master' into master
MahmoudAshraf97 Jun 20, 2024
ad2379b
remove tokenizer reinitialization
MahmoudAshraf97 Jun 20, 2024
abcbedd
remove the need for a separate `encode_batched` function
MahmoudAshraf97 Jun 21, 2024
f584a6c
fix flake8 error
MahmoudAshraf97 Jun 21, 2024
ebf7b65
enable word timestamps using original functions
MahmoudAshraf97 Jun 21, 2024
7f84e34
* remove `PyAV` and use `torchaudio` instead, this fixes the memory l…
MahmoudAshraf97 Jun 22, 2024
b54d828
added back `np.ndarray` support for `transcribe`
MahmoudAshraf97 Jun 24, 2024
2c617c2
fix wrong padding scheme leading to very high WER
MahmoudAshraf97 Jun 24, 2024
99d61e0
remove `num_workers` argument from batched `transcribe`
MahmoudAshraf97 Jun 24, 2024
aef4b97
generalized word timestamps function
MahmoudAshraf97 Jun 24, 2024
5fc5fca
remove redundant parameters related to `num_workers`
MahmoudAshraf97 Jun 25, 2024
389da33
fix word timestamps for non-batched inference
MahmoudAshraf97 Jun 25, 2024
2b0a252
support `without_timestamps` in batched mode
MahmoudAshraf97 Jun 25, 2024
f03d8ca
adjust tests
MahmoudAshraf97 Jun 25, 2024
7c38429
fix typing hints for older python versions
MahmoudAshraf97 Jun 25, 2024
579da0e
correct timestamps
MahmoudAshraf97 Jun 26, 2024
8642f1d
use original `Segment` instead of `BatchedSegment`
MahmoudAshraf97 Jun 27, 2024
6e47bd3
* added `duration_after_vad`, `all_language_probs` to `info`
MahmoudAshraf97 Jun 27, 2024
537317f
formatting changes
MahmoudAshraf97 Jun 27, 2024
74db8be
.
MahmoudAshraf97 Jun 27, 2024
fcf0e82
remove `float16` conversion in feature extractor as it led to halluci…
MahmoudAshraf97 Jun 27, 2024
9f78b36
enable running benchmark from anywhere
MahmoudAshraf97 Jun 29, 2024
d95c7a6
review feature extraction implementation
MahmoudAshraf97 Jun 29, 2024
968057e
formatting fixes
MahmoudAshraf97 Jun 29, 2024
5 changes: 4 additions & 1 deletion benchmark/wer_benchmark.py
@@ -1,5 +1,6 @@
import argparse
import json
import os

from datasets import load_dataset
from evaluate import load
@@ -26,7 +27,9 @@

# define the evaluation metric
wer_metric = load("wer")
normalizer = EnglishTextNormalizer(json.load(open("normalizer.json")))

with open(os.path.join(os.path.dirname(__file__), "normalizer.json"), "r") as f:
normalizer = EnglishTextNormalizer(json.load(f))


def inference(batch):
107 changes: 22 additions & 85 deletions faster_whisper/audio.py
@@ -1,18 +1,7 @@
"""We use the PyAV library to decode the audio: https://github.com/PyAV-Org/PyAV

The advantage of PyAV is that it bundles the FFmpeg libraries so there is no additional
system dependencies. FFmpeg does not need to be installed on the system.

However, the API is quite low-level so we need to manipulate audio frames directly.
"""

import io
import itertools

from typing import BinaryIO, Union

import av
import numpy as np
import torch
import torchaudio


def decode_audio(
@@ -28,94 +17,42 @@ def decode_audio(
split_stereo: Return separate left and right channels.

Returns:
A float32 Numpy array.
A float32 Torch Tensor.

If `split_stereo` is enabled, the function returns a 2-tuple with the
separated left and right channels.
"""
resampler = av.audio.resampler.AudioResampler(
format="s16",
layout="mono" if not split_stereo else "stereo",
rate=sampling_rate,
)

raw_buffer = io.BytesIO()
dtype = None

with av.open(input_file, mode="r", metadata_errors="ignore") as container:
frames = container.decode(audio=0)
frames = _ignore_invalid_frames(frames)
frames = _group_frames(frames, 500000)
frames = _resample_frames(frames, resampler)

for frame in frames:
array = frame.to_ndarray()
dtype = array.dtype
raw_buffer.write(array)

resampler = None
del resampler

# Depending on the number of objects created,
# manually running garbage collector can slow down the processing.
# (https://github.com/SYSTRAN/faster-whisper/pull/856#issuecomment-2175975215)
# gc.collect()

audio = np.frombuffer(raw_buffer.getbuffer(), dtype=dtype)

# Convert s16 back to f32.
audio = audio.astype(np.float32) / 32768.0
waveform, audio_sf = torchaudio.load(input_file) # waveform: channels X T

if audio_sf != sampling_rate:
waveform = torchaudio.functional.resample(
waveform, orig_freq=audio_sf, new_freq=sampling_rate
)
if split_stereo:
left_channel = audio[0::2]
right_channel = audio[1::2]
return left_channel, right_channel

return audio


def _ignore_invalid_frames(frames):
iterator = iter(frames)

while True:
try:
yield next(iterator)
except StopIteration:
break
except av.error.InvalidDataError:
continue


def _group_frames(frames, num_samples=None):
fifo = av.audio.fifo.AudioFifo()

for frame in frames:
frame.pts = None # Ignore timestamp check.
fifo.write(frame)

if num_samples is not None and fifo.samples >= num_samples:
yield fifo.read()

if fifo.samples > 0:
yield fifo.read()

return waveform[0], waveform[1]

def _resample_frames(frames, resampler):
# Add None to flush the resampler.
for frame in itertools.chain(frames, [None]):
yield from resampler.resample(frame)
return waveform.mean(0)


def pad_or_trim(array, length: int, *, axis: int = -1):
"""
Pad or trim the audio array to N_SAMPLES, as expected by the encoder.
"""
axis = axis % array.ndim
if array.shape[axis] > length:
array = array.take(indices=range(length), axis=axis)
idx = [Ellipsis] * axis + [slice(length)] + [Ellipsis] * (array.ndim - axis - 1)
return array[idx]

if array.shape[axis] < length:
pad_widths = [(0, 0)] * array.ndim
pad_widths[axis] = (0, length - array.shape[axis])
array = np.pad(array, pad_widths)
pad_widths = (
[
0,
]
* array.ndim
* 2
)
pad_widths[2 * axis] = length - array.shape[axis]
array = torch.nn.functional.pad(array, tuple(pad_widths[::-1]))

return array
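A short usage sketch of the reworked decoding and padding path (the file name is illustrative; `decode_audio` and `pad_or_trim` are used as defined in this file):

```python
import torch

from faster_whisper.audio import decode_audio, pad_or_trim

# decode_audio now returns a float32 torch.Tensor instead of a NumPy array.
audio = decode_audio("speech.wav", sampling_rate=16000)  # mono mix, shape (T,)

# For a stereo file, split_stereo=True returns the two channels separately:
# left, right = decode_audio("speech_stereo.wav", sampling_rate=16000, split_stereo=True)

# pad_or_trim now pads with torch.nn.functional.pad (or slices) along `axis`.
n_samples = 16000 * 30  # 30 s at 16 kHz, as expected by the encoder
chunk = pad_or_trim(audio, n_samples)
assert chunk.shape[-1] == n_samples
```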
173 changes: 46 additions & 127 deletions faster_whisper/feature_extractor.py
@@ -1,18 +1,21 @@
import numpy as np
import torch
import torchaudio.compliance.kaldi as ta_kaldi


# Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/feature_extraction_whisper.py # noqa: E501
class FeatureExtractor:
def __init__(
self,
device: str = "auto",
feature_size=80,
sampling_rate=16000,
hop_length=160,
chunk_length=30,
n_fft=400,
):
if device == "auto":
Collaborator:
What is the difference in speed if feature extraction is performed on CPU vs GPU? This needs to be evaluated before setting the default (for both short and long audio) in batched and sequential cases.

Author:
20s audio
5.51 ms ± 106 µs per loop (mean ± std. dev. of 7 runs, 100 loops each) # CPU
1.1 ms ± 506 µs per loop (mean ± std. dev. of 7 runs, 1 loop each) # GPU

10min audio
76.3 ms ± 2.62 ms per loop (mean ± std. dev. of 7 runs, 10 loops each) # CPU
8.06 ms ± 335 µs per loop (mean ± std. dev. of 7 runs, 100 loops each) # GPU

around 5x speedup for short audio and 10x for long audio
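
A minimal sketch of how such a timing comparison could be reproduced (illustrative only; it assumes the `FeatureExtractor` from this PR and a synthetic 16 kHz waveform, and the numbers depend on hardware):

```python
import time

import torch

from faster_whisper.feature_extractor import FeatureExtractor

waveform = torch.randn(16000 * 20)  # ~20 s of random audio at 16 kHz

for device in ("cpu", "cuda"):
    if device == "cuda" and not torch.cuda.is_available():
        continue
    fe = FeatureExtractor(device=device)
    fe(waveform)  # warm-up (moves data to the device, initializes CUDA)
    if device == "cuda":
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(10):
        fe(waveform)
    if device == "cuda":
        torch.cuda.synchronize()
    print(f"{device}: {(time.perf_counter() - start) / 10 * 1e3:.2f} ms per call")
```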

self.device = "cuda" if torch.cuda.is_available() else "cpu"
else:
self.device = device
self.n_fft = n_fft
self.hop_length = hop_length
self.chunk_length = chunk_length
@@ -25,21 +28,23 @@ def __init__(
)
self.n_mels = feature_size

def get_mel_filters(self, sr, n_fft, n_mels=128, dtype=np.float32):
@staticmethod
def get_mel_filters(sr, n_fft, n_mels=128, dtype=torch.float32):
"""
Implementation of librosa.filters.mel in Pytorch
"""
# Initialize the weights
n_mels = int(n_mels)
weights = np.zeros((n_mels, int(1 + n_fft // 2)), dtype=dtype)
weights = torch.zeros((n_mels, int(1 + n_fft // 2)), dtype=dtype)

# Center freqs of each FFT bin
fftfreqs = np.fft.rfftfreq(n=n_fft, d=1.0 / sr)
fftfreqs = torch.fft.rfftfreq(n=n_fft, d=1.0 / sr)

# 'Center freqs' of mel bands - uniformly spaced between limits
min_mel = 0.0
max_mel = 45.245640471924965

mels = np.linspace(min_mel, max_mel, n_mels + 2)

mels = np.asanyarray(mels)
mels = torch.linspace(min_mel, max_mel, n_mels + 2)

# Fill in the linear scale
f_min = 0.0
@@ -49,149 +54,63 @@ def get_mel_filters(self, sr, n_fft, n_mels=128, dtype=np.float32):
# And now the nonlinear scale
min_log_hz = 1000.0 # beginning of log region (Hz)
min_log_mel = (min_log_hz - f_min) / f_sp # same (Mels)
logstep = np.log(6.4) / 27.0 # step size for log region
logstep = torch.log(torch.tensor(6.4)) / 27.0 # step size for log region

# If we have vector data, vectorize
log_t = mels >= min_log_mel
freqs[log_t] = min_log_hz * np.exp(logstep * (mels[log_t] - min_log_mel))
freqs[log_t] = min_log_hz * torch.exp(logstep * (mels[log_t] - min_log_mel))

mel_f = freqs

fdiff = np.diff(mel_f)
ramps = np.subtract.outer(mel_f, fftfreqs)
fdiff = torch.diff(mel_f)
ramps = mel_f.view(-1, 1) - fftfreqs.view(1, -1)

for i in range(n_mels):
# lower and upper slopes for all bins
lower = -ramps[i] / fdiff[i]
upper = ramps[i + 2] / fdiff[i + 1]
lower = -ramps[:-2] / fdiff[:-1].unsqueeze(1)
upper = ramps[2:] / fdiff[1:].unsqueeze(1)

# .. then intersect them with each other and zero
weights[i] = np.maximum(0, np.minimum(lower, upper))
# Intersect them with each other and zero, vectorized across all i
weights = torch.maximum(torch.zeros_like(lower), torch.minimum(lower, upper))

# Slaney-style mel is scaled to be approx constant energy per channel
enorm = 2.0 / (mel_f[2 : n_mels + 2] - mel_f[:n_mels])
weights *= enorm[:, np.newaxis]
weights *= enorm.unsqueeze(1)

return weights
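
A quick sanity check of the vectorized filterbank, which is now a `@staticmethod` returning a torch tensor (shapes follow the Whisper defaults assumed here):

```python
import torch

from faster_whisper.feature_extractor import FeatureExtractor

# No instance is needed anymore.
filters = FeatureExtractor.get_mel_filters(sr=16000, n_fft=400, n_mels=80)
print(filters.shape)  # torch.Size([80, 201]) == (n_mels, 1 + n_fft // 2)
print(filters.dtype)  # torch.float32
assert torch.all(filters >= 0)  # triangular mel filters are non-negative
```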

def fram_wave(self, waveform, center=True):
"""
Transform a raw waveform into a list of smaller waveforms.
The window length defines how much of the signal is
contain in each frame (smalle waveform), while the hope length defines the step
between the beginning of each new frame.
Centering is done by reflecting the waveform which is first centered around
`frame_idx * hop_length`.
"""
frames = []
for i in range(0, waveform.shape[0] + 1, self.hop_length):
half_window = (self.n_fft - 1) // 2 + 1
if center:
start = i - half_window if i > half_window else 0
end = (
i + half_window
if i < waveform.shape[0] - half_window
else waveform.shape[0]
)

frame = waveform[start:end]

if start == 0:
padd_width = (-i + half_window, 0)
frame = np.pad(frame, pad_width=padd_width, mode="reflect")

elif end == waveform.shape[0]:
padd_width = (0, (i - waveform.shape[0] + half_window))
frame = np.pad(frame, pad_width=padd_width, mode="reflect")

else:
frame = waveform[i : i + self.n_fft]
frame_width = frame.shape[0]
if frame_width < waveform.shape[0]:
frame = np.lib.pad(
frame,
pad_width=(0, self.n_fft - frame_width),
mode="constant",
constant_values=0,
)

frames.append(frame)
return np.stack(frames, 0)

def stft(self, frames, window):
"""
Calculates the complex Short-Time Fourier Transform (STFT) of the given framed signal.
Should give the same results as `torch.stft`.
"""
frame_size = frames.shape[1]
fft_size = self.n_fft

if fft_size is None:
fft_size = frame_size

if fft_size < frame_size:
raise ValueError("FFT size must greater or equal the frame size")
# number of FFT bins to store
num_fft_bins = (fft_size >> 1) + 1

data = np.empty((len(frames), num_fft_bins), dtype=np.complex64)
fft_signal = np.zeros(fft_size)

for f, frame in enumerate(frames):
if window is not None:
np.multiply(frame, window, out=fft_signal[:frame_size])
else:
fft_signal[:frame_size] = frame
data[f] = np.fft.fft(fft_signal, axis=0)[:num_fft_bins]
return data.T

def __call__(self, waveform, enable_ta=False, padding=True, chunk_length=None):
def __call__(self, waveform, padding=True, chunk_length=None, to_cpu=False):
"""
Compute the log-Mel spectrogram of the provided audio, gives similar results
whisper's original torch implementation with 1e-5 tolerance. Additionally, faster
feature extraction option using kaldi fbank features are available if torchaudio is
available.
Compute the log-Mel spectrogram of the provided audio.
"""
if enable_ta:
waveform = waveform.astype(np.float32)

if chunk_length is not None:
self.n_samples = chunk_length * self.sampling_rate
self.nb_max_frames = self.n_samples // self.hop_length

if waveform.dtype is not torch.float32:
waveform = waveform.to(torch.float32)

waveform = (
waveform.to(self.device)
if self.device == "cuda" and not waveform.is_cuda
else waveform
)

if padding:
waveform = np.pad(waveform, [(0, self.n_samples)])

if enable_ta:
audio = torch.from_numpy(waveform).unsqueeze(0).float()
fbank = ta_kaldi.fbank(
audio,
sample_frequency=self.sampling_rate,
window_type="hanning",
num_mel_bins=self.n_mels,
)
log_spec = fbank.numpy().T.astype(np.float32) # ctranslate does not take 64

# normalize

# Audioset values as default mean and std for audio
mean_val = -4.2677393
std_val = 4.5689974
scaled_features = (log_spec - mean_val) / (std_val * 2)
log_spec = scaled_features
waveform = torch.nn.functional.pad(waveform, (0, self.n_samples))

else:
window = np.hanning(self.n_fft + 1)[:-1]
window = torch.hann_window(self.n_fft).to(waveform.device)

frames = self.fram_wave(waveform)
stft = self.stft(frames, window=window)
magnitudes = np.abs(stft[:, :-1]) ** 2
stft = torch.stft(
waveform, self.n_fft, self.hop_length, window=window, return_complex=True
)
magnitudes = stft[..., :-1].abs() ** 2

filters = self.mel_filters
mel_spec = filters @ magnitudes
mel_spec = self.mel_filters.to(waveform.device) @ magnitudes

log_spec = np.log10(np.clip(mel_spec, a_min=1e-10, a_max=None))
log_spec = np.maximum(log_spec, log_spec.max() - 8.0)
log_spec = (log_spec + 4.0) / 4.0
log_spec = torch.clamp(mel_spec, min=1e-10).log10()
log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
log_spec = (log_spec + 4.0) / 4.0

return log_spec
# When the model is running on multiple GPUs, the output should be moved
# to the CPU since we don't know which GPU will handle the next job.
return log_spec.cpu() if to_cpu else log_spec
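
A minimal end-to-end sketch of the new `__call__` (file name illustrative; `fe.sampling_rate` and `fe.n_samples` are the constructor-derived attributes referenced above):

```python
import torch

from faster_whisper.audio import decode_audio, pad_or_trim
from faster_whisper.feature_extractor import FeatureExtractor

fe = FeatureExtractor(device="auto")  # picks "cuda" when available, else "cpu"

audio = decode_audio("speech.wav", sampling_rate=fe.sampling_rate)
audio = pad_or_trim(audio, fe.n_samples)  # one 30 s chunk

# padding=True appends n_samples of zeros before the STFT (original behaviour);
# to_cpu=True moves the features off the GPU so any worker can take the next job.
features = fe(audio, padding=True, to_cpu=True)
print(features.shape)  # (n_mels, n_frames), e.g. torch.Size([80, 6000]) here
```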