Skip to content

Commit

Permalink
add skipme option
Browse files Browse the repository at this point in the history
Signed-off-by: Monica Sekoyan <[email protected]>
  • Loading branch information
monica-sekoyan committed Jan 8, 2025
1 parent 5d8baa4 commit 653e990
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 1 deletion.
10 changes: 10 additions & 0 deletions nemo/collections/common/data/lhotse/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,8 @@ class LhotseDataLoadingConfig:
# Enables iteration of NeMo non-tarred manifests that don't have a "sampling_rate" key without performing any I/O.
# Note that this will not allow actual dataloading; it's only for manifest iteration as Lhotse objects.
metadata_only: bool = False
# When True, the "_skipme" key in data entries is ignored, so entries with "_skipme" set to 1 are loaded instead of being skipped.
ignore_skipme: bool = False
# Forces the resulting CutSet to be finite, so that the iteration will end after a full single epoch.
# Do not turn this on unless you're sure that you know what you're doing.
# In most cases (such as regular multi-GPU training) it will result in a deadlock due to
Expand Down Expand Up @@ -443,6 +445,14 @@ def get_lhotse_sampler_from_config(config, global_rank, world_size, tokenizer=No
cuts, use_iterable_dataset = read_cutset_from_config(config)
use_iterable_dataset = determine_use_iterable_dataset(use_iterable_dataset, config)

if not config.ignore_skipme:
cuts = cuts.filter(lambda cut: not cut.custom.pop("_skipme", 0))
else:
logging.warning(
"""You have chosen to ignore the '_skipme' keys in your manifest.
This may lead to unintended behavior, potentially including some unwanted samples."""
)

# Apply channel selector
if config.channel_selector is not None:
logging.info('Using channel selector %s.', config.channel_selector)
Expand Down
2 changes: 1 addition & 1 deletion nemo/collections/common/data/lhotse/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,4 +314,4 @@ def _measure_tokens(cut: Cut) -> int:

def _measure_tps(cut: Cut) -> float:
num_tokens = _measure_tokens(cut)
return num_tokens / cut.duration
return num_tokens / cut.duration

0 comments on commit 653e990

Please sign in to comment.