Skip to content

Commit

Permalink
Batched feature extraction for s3prl (#942)
Browse files Browse the repository at this point in the history
This PR enables batch extraction of S3PRL features (similar to
kaldifeat).
  • Loading branch information
pzelasko authored Jan 17, 2023
2 parents 5775647 + fd861c5 commit c33345d
Show file tree
Hide file tree
Showing 4 changed files with 190 additions and 25 deletions.
103 changes: 85 additions & 18 deletions lhotse/features/ssl.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import asdict, dataclass
from typing import Any, Dict, Optional
from typing import Any, Dict, List, Optional, Sequence, Union

import numpy as np
import torch
Expand All @@ -9,7 +9,7 @@
EPSILON,
LOG_EPSILON,
Seconds,
compute_num_frames,
compute_num_frames_from_samples,
is_module_available,
)

Expand Down Expand Up @@ -68,35 +68,102 @@ def __init__(self, config: Optional[Any] = None):
def frame_shift(self) -> Seconds:
return self.config.frame_shift

@property
def sampling_rate(self) -> int:
return self.config.sampling_rate

def feature_dim(self, sampling_rate: int) -> int:
assert (
sampling_rate == 16000
), f"All the upstream models in S3PRL now only support 16 kHz audio."
return self.config.feature_dim

def extract(self, samples: np.ndarray, sampling_rate: int) -> np.ndarray:
def fix_off_by_one_error(self, feats: np.ndarray, num_samples: int) -> np.ndarray:
# The last frame is usually shorter than the others.
# We pad it with zeros to make it the same length as others.
num_frames, num_features = feats.shape
expected_num_frames = compute_num_frames_from_samples(
num_samples=num_samples,
frame_shift=self.frame_shift,
sampling_rate=self.sampling_rate,
)
num_frames_diff = abs(expected_num_frames - num_frames)
assert num_frames_diff <= 1
if num_frames_diff == 1:
pad = np.zeros([1, num_features])
feats = np.concatenate([feats, pad], axis=0)
return feats

def extract_batch(
self,
samples: Union[
np.ndarray, torch.Tensor, Sequence[np.ndarray], Sequence[torch.Tensor]
],
sampling_rate: int,
) -> Union[np.ndarray, torch.Tensor, List[np.ndarray], List[torch.Tensor]]:
return self.extract(samples=samples, sampling_rate=sampling_rate)

def extract(
self,
samples: Union[
np.ndarray, torch.Tensor, Sequence[np.ndarray], Sequence[torch.Tensor]
],
sampling_rate: int,
) -> Union[np.ndarray, List[np.ndarray]]:
assert (
sampling_rate == 16000
), f"All the upstream models in S3PRL now only support 16 kHz audio."

samples = torch.from_numpy(samples).to(self.config.device)
# s3prl expects a batch of 1D torch tensors.
# Regardless of input type, we return a numpy array or list of numpy arrays.

input_is_list = isinstance(samples, list)
if input_is_list:
samples = [s.squeeze() for s in samples]
else:
samples = samples.squeeze()

# Convert input to a list of 1D torch tensors.
if input_is_list or samples.ndim > 1:
samples = [
torch.from_numpy(s) if isinstance(s, np.ndarray) else s for s in samples
]
else:
# The user passed a single array/tensor of shape (num_samples,)
samples = [
torch.from_numpy(samples)
if isinstance(samples, np.ndarray)
else samples
]

samples = [s.to(self.config.device) for s in samples]
lengths = [s.shape[0] for s in samples]

self.ssl_model.eval()
with torch.no_grad():
feats = self.ssl_model(samples)["hidden_states"][self.config.layer]
feats = feats.squeeze()

num_frames, num_features = feats.shape
duration = round(samples.shape[1] / sampling_rate, ndigits=8)
expected_num_frames = compute_num_frames(
duration=duration,
frame_shift=self.frame_shift,
sampling_rate=sampling_rate,
)
num_frames_diff = expected_num_frames - num_frames
assert num_frames_diff <= 1
if num_frames_diff == 1:
pad = torch.zeros([1, num_features]).to(self.config.device)
feats = torch.cat([feats, pad], dim=0)

return feats.cpu().numpy()
if feats.ndim == 2:
# The user passed a single array/tensor of shape (num_samples,)
feats = feats.cpu().numpy()
feats = self.fix_off_by_one_error(feats, lengths[0])
if input_is_list:
feats = [feats]
else:
# The user passed a batch of arrays/tensors of shape (num_samples,)
# Convert the padded sequence to a list of 1D torch tensors.
out_lens = [
compute_num_frames_from_samples(
num_samples, self.config.frame_shift, self.config.sampling_rate
)
for num_samples in lengths
]
feats = [f[:l].cpu().numpy() for f, l in zip(feats, out_lens)]
feats = [self.fix_off_by_one_error(f, l) for f, l in zip(feats, lengths)]

# If all items are of the same shape, stack them into a single array.
if all(item.shape == feats[0].shape for item in feats[1:]):
feats = np.stack(feats, axis=0)

return feats
13 changes: 13 additions & 0 deletions lhotse/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,19 @@ def compute_num_frames(
return num_frames


def compute_num_frames_from_samples(
num_samples: int,
frame_shift: Seconds,
sampling_rate: int,
) -> int:
"""
Compute the number of frames from number of samples and frame_shift in a safe way.
"""
window_hop = round(frame_shift * sampling_rate)
num_frames = int((num_samples + window_hop // 2) // window_hop)
return num_frames


def compute_num_windows(sig_len: Seconds, win_len: Seconds, hop: Seconds) -> int:
"""
Return a number of windows obtained from signal of length equal to ``sig_len``
Expand Down
10 changes: 10 additions & 0 deletions test/cut/test_feature_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import torch

from lhotse import (
S3PRLSSL,
CutSet,
Fbank,
FbankConfig,
Expand Down Expand Up @@ -213,6 +214,15 @@ def test_extract_and_store_features_from_cut_set(
),
],
),
pytest.param(
S3PRLSSL,
marks=[
pytest.mark.skipif(
not is_module_available("s3prl"),
reason="Requires s3prl to run.",
),
],
),
],
)
def test_cut_set_batch_feature_extraction(cut_set, extractor_type):
Expand Down
89 changes: 82 additions & 7 deletions test/features/test_s3prl.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,23 @@
import numpy as np
import pytest
import torch

from lhotse import S3PRLSSL, Recording, S3PRLSSLConfig
from lhotse.utils import compute_num_frames, compute_num_samples, is_module_available

s3prl = pytest.importorskip("s3prl", reason="The test requires s3prl to run.")


@pytest.fixture()
def recording():
return Recording.from_file("test/fixtures/libri/libri-1088-134315-0000.wav")


@pytest.mark.skipif(
not is_module_available("s3prl"), reason="The test requires s3prl to run."
)
def test_s3prl_feature_extractor_default_config(recording):
feature_extractor = S3PRLSSL()
y = feature_extractor.extract(recording.load_audio(), recording.sampling_rate)
assert np.shape(y) == (802, 1024)


@pytest.mark.skipif(
not is_module_available("s3prl"), reason="The test requires s3prl to run."
)
def test_s3prl_feature_extractor_config(recording):
config = S3PRLSSLConfig(
ssl_model="wav2vec2",
Expand All @@ -30,3 +26,82 @@ def test_s3prl_feature_extractor_config(recording):
feature_extractor = S3PRLSSL(config=config)
y = feature_extractor.extract(recording.load_audio(), recording.sampling_rate)
assert np.shape(y) == (802, 768)


@pytest.mark.parametrize(
"input",
[
np.arange(8000, dtype=np.float32),
torch.arange(8000, dtype=torch.float32),
torch.arange(8000, dtype=torch.float32),
],
)
def test_s3prl_supports_single_input_waveform(input):
config = S3PRLSSLConfig(
ssl_model="wav2vec2",
feature_dim=768,
)
fe = S3PRLSSL(config=config)
feats = fe.extract(input, sampling_rate=16000)
assert feats.shape == (25, 768)


@pytest.mark.parametrize(
"input",
[
[np.arange(8000, dtype=np.float32)],
[np.arange(8000, dtype=np.float32).reshape(1, -1)],
[torch.arange(8000, dtype=torch.float32).unsqueeze(0)],
],
)
def test_s3prl_supports_list_with_single_input_waveform(input):
config = S3PRLSSLConfig(
ssl_model="wav2vec2",
feature_dim=768,
)
fe = S3PRLSSL(config=config)
feats = fe.extract(input, sampling_rate=16000)
assert isinstance(feats, list)
assert len(feats) == 1
assert feats[0].shape == (25, 768)


@pytest.mark.parametrize(
"input",
[
[
np.arange(8000, dtype=np.float32),
np.arange(8000, dtype=np.float32),
],
[
torch.arange(8000, dtype=torch.float32),
torch.arange(8000, dtype=torch.float32),
],
],
)
def test_s3prl_supports_list_of_even_len_inputs(input):
config = S3PRLSSLConfig(
ssl_model="wav2vec2",
feature_dim=768,
)
fe = S3PRLSSL(config=config)
feats = fe.extract(input, sampling_rate=16000)
assert feats.ndim == 3
assert feats.shape == (2, 25, 768)


def test_s3prl_supports_list_of_uneven_len_inputs():
input = [
torch.arange(8000, dtype=torch.float32),
torch.arange(16000, dtype=torch.float32),
]
config = S3PRLSSLConfig(
ssl_model="wav2vec2",
feature_dim=768,
)
fe = S3PRLSSL(config=config)
feats = fe.extract(input, sampling_rate=16000)
assert len(feats) == 2
f1, f2 = feats
assert f1.shape == (25, 768)
assert f2.shape == (50, 768)

0 comments on commit c33345d

Please sign in to comment.