diff --git a/wenet/dataset/datapipes.py b/wenet/dataset/datapipes.py
index 54127a821..f7bd75bd4 100644
--- a/wenet/dataset/datapipes.py
+++ b/wenet/dataset/datapipes.py
@@ -15,10 +15,11 @@
 import collections
 from collections.abc import Callable
 import copy
+from os import PathLike
 import sys
 import tarfile
 import logging
-from typing import List, Optional
+from typing import List, Optional, Union
 import numpy as np
 import torch
 from torch.utils.data import IterDataPipe, functional_datapipe
@@ -26,6 +27,7 @@ from torch.utils.data.datapipes.iter import Mapper
 from torch.utils.data.datapipes.iter import Mapper
 from torch.utils.data.datapipes.iter.sharding import (
     SHARDING_PRIORITIES, ShardingFilterIterDataPipe)
+from torch.utils.data.datapipes.iter.utils import IterableWrapperIterDataPipe
 from torch.utils.data.datapipes.utils.common import _check_unpickable_fn
 
 from wenet.dataset.processor import parse_url
@@ -430,18 +432,72 @@ def __iter__(self):
 class WenetRawDatasetSource(IterDataPipe):
 
     def __init__(self,
-                 filenames: str,
+                 filenames: Union[str, List[Union[str, PathLike]]],
                  prefetch: int = 500,
                  partition: bool = True,
                  shuffle: bool = False,
                  shuffle_size: int = 10000,
-                 cycle: int = 1) -> None:
+                 cycle: int = 1,
+                 shard_by_files: bool = False) -> None:
+        """Build the text-line source pipeline over ``filenames``.
+
+        Args:
+            filenames: path(s) handed to ``TextLineDataPipe``; must be
+                a list of paths when ``shard_by_files`` is True.
+            prefetch: argument forwarded to ``prefetch()``.
+            partition: argument forwarded to ``shard()``.
+            shuffle: shuffle the text lines when True.
+            shuffle_size: buffer size of the line-level shuffle.
+            cycle: argument forwarded to ``repeat()``.
+            shard_by_files: shuffle, repeat and shard the file list
+                itself, then read lines; otherwise shard line by line.
+
+        Raises:
+            TypeError: ``shard_by_files`` is True and ``filenames`` is
+                a single path instead of a list of paths.
+            ValueError: built inside a DataLoader worker with fewer
+                files than shards (workers times world size).
+        """
         super().__init__()
-        self.dp = TextLineDataPipe(filenames)
-        if shuffle:
-            self.dp = self.dp.shuffle(buffer_size=shuffle_size)
-        self.dp = self.dp.repeat(cycle).prefetch(prefetch)
-        self.dp = self.dp.shard(partition)
+        if shard_by_files:
+            # A bare string would be iterated character by character.
+            if isinstance(filenames, (str, PathLike)):
+                raise TypeError(
+                    "shard_by_files=True requires a list of files, got "
+                    f"{type(filenames).__name__}")
+            files = list(filenames)
+            # get_worker_info() returns None outside a DataLoader worker,
+            # so this check only runs when constructed inside a worker.
+            info = torch.utils.data.get_worker_info()
+            if info is not None:
+                world_size = 1
+                if (torch.distributed.is_available()
+                        and torch.distributed.is_initialized()):
+                    world_size = torch.distributed.get_world_size()
+                n_shards = info.num_workers * world_size
+                n_items = len(files) * cycle
+                # Each shard needs at least one file or it yields nothing.
+                # NOTE(review): skipped for cycle <= 0; confirm what
+                # repeat() does with such values.
+                if cycle > 0 and n_items < n_shards:
+                    raise ValueError(
+                        f"shard_by_files: {n_items} file item(s) cannot "
+                        f"cover {n_shards} shard(s)")
+            # 0 shard many jsonl files
+            # NOTE(review): the file list is shuffled even when shuffle=False.
+            dp = IterableWrapperIterDataPipe(files)
+            dp = dp.shuffle().repeat(cycle).shard(partition)
+            # 1 read one json file
+            self.dp = TextLineDataPipe(dp)
+            if shuffle:
+                self.dp = self.dp.shuffle(buffer_size=shuffle_size)
+            self.dp = self.dp.prefetch(prefetch)
+        else:
+            # shard line by line
+            self.dp = TextLineDataPipe(filenames)
+            if shuffle:
+                self.dp = self.dp.shuffle(buffer_size=shuffle_size)
+            self.dp = self.dp.repeat(cycle).prefetch(prefetch)
+            self.dp = self.dp.shard(partition)
 
     def __iter__(self):
         for d in self.dp: