Skip to content

Commit

Permalink
Merge pull request #18 from MahmoudAshraf97/master
Browse files Browse the repository at this point in the history
removing the need for `jsons` dependency
  • Loading branch information
Jiltseb authored Jul 1, 2024
2 parents 307de38 + 968057e commit eff81f5
Show file tree
Hide file tree
Showing 7 changed files with 510 additions and 776 deletions.
5 changes: 4 additions & 1 deletion benchmark/wer_benchmark.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import argparse
import json
import os

from datasets import load_dataset
from evaluate import load
Expand All @@ -26,7 +27,9 @@

# define the evaluation metric
wer_metric = load("wer")
normalizer = EnglishTextNormalizer(json.load(open("normalizer.json")))

with open(os.path.join(os.path.dirname(__file__), "normalizer.json"), "r") as f:
normalizer = EnglishTextNormalizer(json.load(f))


def inference(batch):
Expand Down
107 changes: 22 additions & 85 deletions faster_whisper/audio.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,7 @@
"""We use the PyAV library to decode the audio: https://github.com/PyAV-Org/PyAV
The advantage of PyAV is that it bundles the FFmpeg libraries so there is no additional
system dependencies. FFmpeg does not need to be installed on the system.
However, the API is quite low-level so we need to manipulate audio frames directly.
"""

import io
import itertools

from typing import BinaryIO, Union

import av
import numpy as np
import torch
import torchaudio


def decode_audio(
Expand All @@ -28,94 +17,42 @@ def decode_audio(
split_stereo: Return separate left and right channels.
Returns:
A float32 Numpy array.
A float32 Torch Tensor.
If `split_stereo` is enabled, the function returns a 2-tuple with the
separated left and right channels.
"""
resampler = av.audio.resampler.AudioResampler(
format="s16",
layout="mono" if not split_stereo else "stereo",
rate=sampling_rate,
)

raw_buffer = io.BytesIO()
dtype = None

with av.open(input_file, mode="r", metadata_errors="ignore") as container:
frames = container.decode(audio=0)
frames = _ignore_invalid_frames(frames)
frames = _group_frames(frames, 500000)
frames = _resample_frames(frames, resampler)

for frame in frames:
array = frame.to_ndarray()
dtype = array.dtype
raw_buffer.write(array)

resampler = None
del resampler

# Depending on the number of objects created,
# manually running garbage collector can slow down the processing.
# (https://github.com/SYSTRAN/faster-whisper/pull/856#issuecomment-2175975215)
# gc.collect()

audio = np.frombuffer(raw_buffer.getbuffer(), dtype=dtype)

# Convert s16 back to f32.
audio = audio.astype(np.float32) / 32768.0
waveform, audio_sf = torchaudio.load(input_file) # waveform: channels X T

if audio_sf != sampling_rate:
waveform = torchaudio.functional.resample(
waveform, orig_freq=audio_sf, new_freq=sampling_rate
)
if split_stereo:
left_channel = audio[0::2]
right_channel = audio[1::2]
return left_channel, right_channel

return audio


def _ignore_invalid_frames(frames):
iterator = iter(frames)

while True:
try:
yield next(iterator)
except StopIteration:
break
except av.error.InvalidDataError:
continue


def _group_frames(frames, num_samples=None):
fifo = av.audio.fifo.AudioFifo()

for frame in frames:
frame.pts = None # Ignore timestamp check.
fifo.write(frame)

if num_samples is not None and fifo.samples >= num_samples:
yield fifo.read()

if fifo.samples > 0:
yield fifo.read()

return waveform[0], waveform[1]

def _resample_frames(frames, resampler):
# Add None to flush the resampler.
for frame in itertools.chain(frames, [None]):
yield from resampler.resample(frame)
return waveform.mean(0)


def pad_or_trim(array, length: int, *, axis: int = -1):
"""
Pad or trim the audio array to N_SAMPLES, as expected by the encoder.
"""
axis = axis % array.ndim
if array.shape[axis] > length:
array = array.take(indices=range(length), axis=axis)
idx = [Ellipsis] * axis + [slice(length)] + [Ellipsis] * (array.ndim - axis - 1)
return array[idx]

if array.shape[axis] < length:
pad_widths = [(0, 0)] * array.ndim
pad_widths[axis] = (0, length - array.shape[axis])
array = np.pad(array, pad_widths)
pad_widths = (
[
0,
]
* array.ndim
* 2
)
pad_widths[2 * axis] = length - array.shape[axis]
array = torch.nn.functional.pad(array, tuple(pad_widths[::-1]))

return array
173 changes: 46 additions & 127 deletions faster_whisper/feature_extractor.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,21 @@
import numpy as np
import torch
import torchaudio.compliance.kaldi as ta_kaldi


# Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/feature_extraction_whisper.py # noqa: E501
class FeatureExtractor:
def __init__(
self,
device: str = "auto",
feature_size=80,
sampling_rate=16000,
hop_length=160,
chunk_length=30,
n_fft=400,
):
if device == "auto":
self.device = "cuda" if torch.cuda.is_available() else "cpu"
else:
self.device = device
self.n_fft = n_fft
self.hop_length = hop_length
self.chunk_length = chunk_length
Expand All @@ -25,21 +28,23 @@ def __init__(
)
self.n_mels = feature_size

def get_mel_filters(self, sr, n_fft, n_mels=128, dtype=np.float32):
@staticmethod
def get_mel_filters(sr, n_fft, n_mels=128, dtype=torch.float32):
"""
Implementation of librosa.filters.mel in Pytorch
"""
# Initialize the weights
n_mels = int(n_mels)
weights = np.zeros((n_mels, int(1 + n_fft // 2)), dtype=dtype)
weights = torch.zeros((n_mels, int(1 + n_fft // 2)), dtype=dtype)

# Center freqs of each FFT bin
fftfreqs = np.fft.rfftfreq(n=n_fft, d=1.0 / sr)
fftfreqs = torch.fft.rfftfreq(n=n_fft, d=1.0 / sr)

# 'Center freqs' of mel bands - uniformly spaced between limits
min_mel = 0.0
max_mel = 45.245640471924965

mels = np.linspace(min_mel, max_mel, n_mels + 2)

mels = np.asanyarray(mels)
mels = torch.linspace(min_mel, max_mel, n_mels + 2)

# Fill in the linear scale
f_min = 0.0
Expand All @@ -49,149 +54,63 @@ def get_mel_filters(self, sr, n_fft, n_mels=128, dtype=np.float32):
# And now the nonlinear scale
min_log_hz = 1000.0 # beginning of log region (Hz)
min_log_mel = (min_log_hz - f_min) / f_sp # same (Mels)
logstep = np.log(6.4) / 27.0 # step size for log region
logstep = torch.log(torch.tensor(6.4)) / 27.0 # step size for log region

# If we have vector data, vectorize
log_t = mels >= min_log_mel
freqs[log_t] = min_log_hz * np.exp(logstep * (mels[log_t] - min_log_mel))
freqs[log_t] = min_log_hz * torch.exp(logstep * (mels[log_t] - min_log_mel))

mel_f = freqs

fdiff = np.diff(mel_f)
ramps = np.subtract.outer(mel_f, fftfreqs)
fdiff = torch.diff(mel_f)
ramps = mel_f.view(-1, 1) - fftfreqs.view(1, -1)

for i in range(n_mels):
# lower and upper slopes for all bins
lower = -ramps[i] / fdiff[i]
upper = ramps[i + 2] / fdiff[i + 1]
lower = -ramps[:-2] / fdiff[:-1].unsqueeze(1)
upper = ramps[2:] / fdiff[1:].unsqueeze(1)

# .. then intersect them with each other and zero
weights[i] = np.maximum(0, np.minimum(lower, upper))
# Intersect them with each other and zero, vectorized across all i
weights = torch.maximum(torch.zeros_like(lower), torch.minimum(lower, upper))

# Slaney-style mel is scaled to be approx constant energy per channel
enorm = 2.0 / (mel_f[2 : n_mels + 2] - mel_f[:n_mels])
weights *= enorm[:, np.newaxis]
weights *= enorm.unsqueeze(1)

return weights

def fram_wave(self, waveform, center=True):
"""
Transform a raw waveform into a list of smaller waveforms.
The window length defines how much of the signal is
contain in each frame (smalle waveform), while the hope length defines the step
between the beginning of each new frame.
Centering is done by reflecting the waveform which is first centered around
`frame_idx * hop_length`.
"""
frames = []
for i in range(0, waveform.shape[0] + 1, self.hop_length):
half_window = (self.n_fft - 1) // 2 + 1
if center:
start = i - half_window if i > half_window else 0
end = (
i + half_window
if i < waveform.shape[0] - half_window
else waveform.shape[0]
)

frame = waveform[start:end]

if start == 0:
padd_width = (-i + half_window, 0)
frame = np.pad(frame, pad_width=padd_width, mode="reflect")

elif end == waveform.shape[0]:
padd_width = (0, (i - waveform.shape[0] + half_window))
frame = np.pad(frame, pad_width=padd_width, mode="reflect")

else:
frame = waveform[i : i + self.n_fft]
frame_width = frame.shape[0]
if frame_width < waveform.shape[0]:
frame = np.lib.pad(
frame,
pad_width=(0, self.n_fft - frame_width),
mode="constant",
constant_values=0,
)

frames.append(frame)
return np.stack(frames, 0)

def stft(self, frames, window):
"""
Calculates the complex Short-Time Fourier Transform (STFT) of the given framed signal.
Should give the same results as `torch.stft`.
"""
frame_size = frames.shape[1]
fft_size = self.n_fft

if fft_size is None:
fft_size = frame_size

if fft_size < frame_size:
raise ValueError("FFT size must greater or equal the frame size")
# number of FFT bins to store
num_fft_bins = (fft_size >> 1) + 1

data = np.empty((len(frames), num_fft_bins), dtype=np.complex64)
fft_signal = np.zeros(fft_size)

for f, frame in enumerate(frames):
if window is not None:
np.multiply(frame, window, out=fft_signal[:frame_size])
else:
fft_signal[:frame_size] = frame
data[f] = np.fft.fft(fft_signal, axis=0)[:num_fft_bins]
return data.T

def __call__(self, waveform, enable_ta=False, padding=True, chunk_length=None):
def __call__(self, waveform, padding=True, chunk_length=None, to_cpu=False):
"""
Compute the log-Mel spectrogram of the provided audio, gives similar results
whisper's original torch implementation with 1e-5 tolerance. Additionally, faster
feature extraction option using kaldi fbank features are available if torchaudio is
available.
Compute the log-Mel spectrogram of the provided audio.
"""
if enable_ta:
waveform = waveform.astype(np.float32)

if chunk_length is not None:
self.n_samples = chunk_length * self.sampling_rate
self.nb_max_frames = self.n_samples // self.hop_length

if waveform.dtype is not torch.float32:
waveform = waveform.to(torch.float32)

waveform = (
waveform.to(self.device)
if self.device == "cuda" and not waveform.is_cuda
else waveform
)

if padding:
waveform = np.pad(waveform, [(0, self.n_samples)])

if enable_ta:
audio = torch.from_numpy(waveform).unsqueeze(0).float()
fbank = ta_kaldi.fbank(
audio,
sample_frequency=self.sampling_rate,
window_type="hanning",
num_mel_bins=self.n_mels,
)
log_spec = fbank.numpy().T.astype(np.float32) # ctranslate does not take 64

# normalize

# Audioset values as default mean and std for audio
mean_val = -4.2677393
std_val = 4.5689974
scaled_features = (log_spec - mean_val) / (std_val * 2)
log_spec = scaled_features
waveform = torch.nn.functional.pad(waveform, (0, self.n_samples))

else:
window = np.hanning(self.n_fft + 1)[:-1]
window = torch.hann_window(self.n_fft).to(waveform.device)

frames = self.fram_wave(waveform)
stft = self.stft(frames, window=window)
magnitudes = np.abs(stft[:, :-1]) ** 2
stft = torch.stft(
waveform, self.n_fft, self.hop_length, window=window, return_complex=True
)
magnitudes = stft[..., :-1].abs() ** 2

filters = self.mel_filters
mel_spec = filters @ magnitudes
mel_spec = self.mel_filters.to(waveform.device) @ magnitudes

log_spec = np.log10(np.clip(mel_spec, a_min=1e-10, a_max=None))
log_spec = np.maximum(log_spec, log_spec.max() - 8.0)
log_spec = (log_spec + 4.0) / 4.0
log_spec = torch.clamp(mel_spec, min=1e-10).log10()
log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
log_spec = (log_spec + 4.0) / 4.0

return log_spec
# When the model is running on multiple GPUs, the output should be moved
# to the CPU since we don't know which GPU will handle the next job.
return log_spec.cpu() if to_cpu else log_spec
Loading

0 comments on commit eff81f5

Please sign in to comment.