Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add "_skipme" option to Lhotse Dataloading #11793

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
9 changes: 9 additions & 0 deletions nemo/collections/common/data/lhotse/cutset.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
LazyNeMoTarredIterator,
expand_sharded_filepaths,
)
from nemo.collections.common.data.lhotse.sampling import PlaceholderFilter
from nemo.collections.common.data.lhotse.text_adapters import (
LhotseTextAdapter,
LhotseTextPairAdapter,
Expand Down Expand Up @@ -58,6 +59,10 @@ def read_cutset_from_config(config: Union[DictConfig, dict]) -> Tuple[CutSet, bo
cuts, is_tarred = read_nemo_manifest(config)
else:
cuts, is_tarred = read_lhotse_manifest(config)

# After reading the cuts, drop every cut whose "_skipme" flag is set (truthy).
# The same filtering is also applied to each cutset before mixing; here it covers non-mixed cutsets.
cuts = cuts.filter(PlaceholderFilter())
return cuts, is_tarred


Expand Down Expand Up @@ -408,6 +413,8 @@ def read_lhotse_manifest(config) -> tuple[CutSet, bool]:
logging.info(f"- {path=} {weight=}")
cutsets.append(cs)
weights.append(weight)

cutsets = [cutset.filter(PlaceholderFilter()) for cutset in cutsets]
cuts = mux(
*cutsets,
weights=weights,
Expand Down Expand Up @@ -597,6 +604,8 @@ def read_nemo_manifest(config) -> tuple[CutSet, bool]:
cutsets.append(CutSet(nemo_iter))
weights.append(weight)
# Finally, we multiplex the dataset streams to mix the data.
# Before that, we filter each cutset to drop cuts whose "_skipme" flag is set, so that only retained cuts are mixed.
cutsets = [cutset.filter(PlaceholderFilter()) for cutset in cutsets]
cuts = mux(
*cutsets,
weights=weights,
Expand Down
14 changes: 14 additions & 0 deletions nemo/collections/common/data/lhotse/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,3 +383,17 @@ def _measure_tokens(cut: Cut) -> int:
def _measure_tps(cut: Cut) -> float:
    """Return the cut's token count (from ``_measure_tokens``) divided by ``cut.duration``."""
    return _measure_tokens(cut) / cut.duration


class PlaceholderFilter:
    """
    Callable filter predicate that rejects cuts flagged with a truthy "_skipme" custom field.

    Returns ``False`` (drop the example) when the example is a ``Cut`` whose ``custom``
    mapping holds a truthy ``"_skipme"`` value, and ``True`` (keep the example) otherwise,
    including when ``custom`` is missing/``None`` or the key is absent.
    Acts as a pass-through for objects of other type than Cut.
    """

    def __call__(self, example) -> bool:
        # Non-Cut examples are never filtered out.
        if not isinstance(example, Cut):
            return True

        custom = getattr(example, "custom", None)
        if custom is None:
            return True

        # Read the flag with ``get`` instead of ``pop``: the predicate must not mutate the cut,
        # otherwise a flagged cut object that is evaluated a second time would have lost its
        # flag and would be let through.
        return not custom.get("_skipme", False)
80 changes: 80 additions & 0 deletions tests/collections/common/test_lhotse_dataloading.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,21 @@ def nemo_manifest_path(cutset_path: Path):
return p


@pytest.fixture(scope="session")
def nemo_manifest_with_skipme_path(nemo_manifest_path: Path) -> Path:
    """
    Write a copy of the NeMo manifest in which the last two entries carry ``_skipme = True``.

    The copy is stored next to the source manifest as ``nemo_manifest_with_skipme.json``
    and its path is returned.
    """
    from lhotse.serialization import load_jsonl, save_to_jsonl

    out_path = nemo_manifest_path.parent / "nemo_manifest_with_skipme.json"

    # Materialize the entries so the tail can be sliced and edited in place.
    entries = list(load_jsonl(nemo_manifest_path))
    for flagged in entries[-2:]:
        flagged["_skipme"] = True

    save_to_jsonl(entries, out_path)
    return out_path


@pytest.fixture(scope="session")
def mc_cutset_path(tmp_path_factory) -> Path:
"""10 two-channel utterances of length 1s as a Lhotse CutSet."""
Expand Down Expand Up @@ -169,6 +184,24 @@ def nemo_tarred_manifest_path(nemo_manifest_path: Path) -> Tuple[str, str]:
return mft_writer.path, f"{root}/audios__OP_0..1_CL_.tar"


@pytest.fixture(scope="session")
def nemo_tarred_manifest_with_skipme_path(nemo_tarred_manifest_path: Tuple[str, str]) -> Tuple[Path, str]:
    """
    Create a nemo tarred manifest with last 2 utterances out of 10 with `_skipme` key enabled.

    Args:
        nemo_tarred_manifest_path: ``(manifest_path, tar_path)`` pair produced by the
            ``nemo_tarred_manifest_path`` fixture.

    Returns:
        ``(new_manifest_path, tar_path)`` where the new manifest is written next to the
        original one and the tar path is passed through unchanged.
    """
    from lhotse.serialization import load_jsonl, save_to_jsonl

    json_p, tar_p = nemo_tarred_manifest_path

    all_items = list(load_jsonl(json_p))

    for item in all_items[-2:]:
        item['_skipme'] = True

    # Path(...) accepts both str and Path, so this works regardless of which one the
    # upstream fixture yields for the manifest location.
    p = Path(json_p).parent / "tarred_audio_filepaths_with_skipme.jsonl"
    save_to_jsonl(all_items, p)

    return p, tar_p


@pytest.fixture(scope="session")
def nemo_tarred_manifest_path_multi(nemo_tarred_manifest_path: tuple[str, str]) -> Tuple[str, str]:
"""10 utterances of length 1s as a NeMo tarred manifest. Stored in one manifest per shard."""
Expand Down Expand Up @@ -2469,3 +2502,50 @@ def test_dataloader_from_tarred_nemo_subset_manifest(nemo_tarred_manifest_subset
seen_ids_set = set(seen_ids)
assert len(seen_ids_set) == len(seen_ids), "Duplicate IDs found in the batch."
assert seen_ids_set == expected_ids, "The set of IDs in the batches does not match the input JSON manifests."


def test_dataloader_from_nemo_manifest_with_skipme(nemo_manifest_with_skipme_path: Path):
    """Check that a dataloader built from a manifest with ``_skipme`` entries yields no flagged cuts."""
    settings = {
        "manifest_filepath": nemo_manifest_with_skipme_path,
        "sample_rate": 16000,
        "shuffle": True,
        "use_lhotse": True,
        "num_workers": 0,
        "batch_size": 1,
        # lhotse specific
        "use_bucketing": False,
    }
    config = OmegaConf.create(settings)

    dl = get_lhotse_dataloader_from_config(config=config, global_rank=0, world_size=1, dataset=_Identity())
    batches = list(dl)
    skip_flags = [cut.custom.get('_skipme', 0) for batch in batches for cut in batch]

    # The fixture flags 2 of the 10 utterances; with batch_size=1 that leaves 8 batches.
    assert len(batches) == 8
    assert not any(skip_flags)


def test_dataloader_from_tarred_nemo_manifest_with_skipme(nemo_tarred_manifest_with_skipme_path: tuple[Path, str]):
    """Check that a dataloader built from a tarred manifest with ``_skipme`` entries yields no flagged cuts."""
    json_mft, tar_mft = nemo_tarred_manifest_with_skipme_path
    settings = {
        "manifest_filepath": json_mft,
        "tarred_audio_filepaths": tar_mft,
        "sample_rate": 16000,
        "shuffle": True,
        "use_lhotse": True,
        "num_workers": 0,
        "batch_size": 1,
        # lhotse specific
        "use_bucketing": False,
        "force_finite": True,
    }
    config = OmegaConf.create(settings)

    dl = get_lhotse_dataloader_from_config(config=config, global_rank=0, world_size=1, dataset=_Identity())
    batches = list(dl)
    skip_flags = [cut.custom.get('_skipme', 0) for batch in batches for cut in batch]

    # The fixture flags 2 of the 10 utterances; with batch_size=1 that leaves 8 batches.
    assert len(batches) == 8
    assert not any(skip_flags)
Loading