[dataset] support sharding by many jsonl files #2637

Open · wants to merge 5 commits into base: main
38 changes: 30 additions & 8 deletions wenet/dataset/datapipes.py
@@ -15,17 +15,19 @@
 import collections
 from collections.abc import Callable
 import copy
+from os import PathLike
 import sys
 import tarfile
 import logging
-from typing import List, Optional
+from typing import List, Optional, Union
 import numpy as np
 import torch
 from torch.utils.data import IterDataPipe, functional_datapipe
 from torch.utils.data import datapipes
 from torch.utils.data.datapipes.iter import Mapper
 from torch.utils.data.datapipes.iter.sharding import (
     SHARDING_PRIORITIES, ShardingFilterIterDataPipe)
+from torch.utils.data.datapipes.iter.utils import IterableWrapperIterDataPipe
 from torch.utils.data.datapipes.utils.common import _check_unpickable_fn
 
 from wenet.dataset.processor import parse_url
@@ -430,18 +432,38 @@ def __iter__(self):
 class WenetRawDatasetSource(IterDataPipe):
 
     def __init__(self,
-                 filenames: str,
+                 filenames: Union[str, List[PathLike]],
                  prefetch: int = 500,
                  partition: bool = True,
                  shuffle: bool = False,
                  shuffle_size: int = 10000,
-                 cycle: int = 1) -> None:
+                 cycle: int = 1,
+                 shard_by_files: bool = False) -> None:
         super().__init__()
-        self.dp = TextLineDataPipe(filenames)
-        if shuffle:
-            self.dp = self.dp.shuffle(buffer_size=shuffle_size)
-        self.dp = self.dp.repeat(cycle).prefetch(prefetch)
-        self.dp = self.dp.shard(partition)
+        if shard_by_files:
+            # shard by whole files: need files * cycle >= total workers
+            info = torch.utils.data.get_worker_info()
+            if info is not None:
+                assert isinstance(filenames, List)
+                n_workers_per_device = info.num_workers
+                world_size = torch.distributed.get_world_size()
+                assert len(filenames) * cycle >= (
+                    n_workers_per_device * world_size)
+            dp = IterableWrapperIterDataPipe(filenames)
+            # step 0: shuffle, repeat and shard the list of jsonl files
+            dp = dp.shuffle().repeat(cycle).shard(partition)
+            # step 1: each worker reads lines only from its own files
+            self.dp = TextLineDataPipe(dp)
+            if shuffle:
+                self.dp = self.dp.shuffle(buffer_size=shuffle_size)
+            self.dp = self.dp.prefetch(prefetch)
+        else:
+            # shard line by line
+            self.dp = TextLineDataPipe(filenames)
+            if shuffle:
+                self.dp = self.dp.shuffle(buffer_size=shuffle_size)
+            self.dp = self.dp.repeat(cycle).prefetch(prefetch)
+            self.dp = self.dp.shard(partition)
 
     def __iter__(self):
         for d in self.dp:
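For reference, a minimal usage sketch of the new flag. This is not part of the PR: the file names, worker count, and DataLoader settings below are illustrative assumptions, and only WenetRawDatasetSource and its shard_by_files parameter come from this diff.

from torch.utils.data import DataLoader

from wenet.dataset.datapipes import WenetRawDatasetSource

# Hypothetical shard list: eight jsonl files, one shard each.
files = [f"data/train.{i:02d}.jsonl" for i in range(8)]

dataset = WenetRawDatasetSource(
    files,
    shuffle=True,          # shuffle lines within each worker's files
    shard_by_files=True)   # new flag introduced by this PR

# With shard_by_files=True, files * cycle (8 here) should be at least
# num_workers * world_size so every worker receives a file shard.
loader = DataLoader(dataset, batch_size=None, num_workers=4)
for sample in loader:
    ...  # each element wraps one raw jsonl line

The trade-off between the two branches: line-level sharding balances load at any worker count but makes every worker read every file and discard the lines it does not own, while file-level sharding avoids that redundant I/O at the cost of requiring at least as many file shards as workers, which is what the assertion in the diff guards.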