Skip to content

Commit

Permalink
Fix torchdata import error (#2242)
Browse files Browse the repository at this point in the history
* Remove stuff

* stuff

* lint
  • Loading branch information
NicolasHug authored and huydhn committed Mar 23, 2024
1 parent 4a31078 commit bc0913f
Show file tree
Hide file tree
Showing 31 changed files with 51 additions and 73 deletions.
3 changes: 1 addition & 2 deletions torchtext/datasets/ag_news.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
from functools import partial
from typing import Union, Tuple

from torchdata.datapipes.iter import FileOpener, IterableWrapper
from torchtext._download_hooks import HttpReader
from torchtext._internal.module_utils import is_module_available
from torchtext.data.datasets_utils import (
_wrap_split_argument,
Expand Down Expand Up @@ -65,6 +63,7 @@ def AG_NEWS(root: str, split: Union[Tuple[str], str]):
raise ModuleNotFoundError(
"Package `torchdata` not found. Please install following instructions at https://github.com/pytorch/data"
)
from torchdata.datapipes.iter import FileOpener, GDriveReader, HttpReader, IterableWrapper # noqa

url_dp = IterableWrapper([URL[split]])
cache_dp = url_dp.on_disk_cache(
Expand Down
3 changes: 1 addition & 2 deletions torchtext/datasets/amazonreviewfull.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
from functools import partial
from typing import Union, Tuple

from torchdata.datapipes.iter import FileOpener, IterableWrapper
from torchtext._download_hooks import GDriveReader
from torchtext._internal.module_utils import is_module_available
from torchtext.data.datasets_utils import (
_wrap_split_argument,
Expand Down Expand Up @@ -79,6 +77,7 @@ def AmazonReviewFull(root: str, split: Union[Tuple[str], str]):
raise ModuleNotFoundError(
"Package `torchdata` not found. Please install following instructions at https://github.com/pytorch/data"
)
from torchdata.datapipes.iter import FileOpener, GDriveReader, HttpReader, IterableWrapper # noqa

url_dp = IterableWrapper([URL])
cache_compressed_dp = url_dp.on_disk_cache(
Expand Down
3 changes: 1 addition & 2 deletions torchtext/datasets/amazonreviewpolarity.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
from functools import partial
from typing import Union, Tuple

from torchdata.datapipes.iter import FileOpener, IterableWrapper
from torchtext._download_hooks import GDriveReader
from torchtext._internal.module_utils import is_module_available
from torchtext.data.datasets_utils import (
_wrap_split_argument,
Expand Down Expand Up @@ -76,6 +74,7 @@ def AmazonReviewPolarity(root: str, split: Union[Tuple[str], str]):
raise ModuleNotFoundError(
"Package `torchdata` not found. Please install following instructions at https://github.com/pytorch/data"
)
from torchdata.datapipes.iter import FileOpener, GDriveReader, HttpReader, IterableWrapper # noqa

url_dp = IterableWrapper([URL])
cache_compressed_dp = url_dp.on_disk_cache(
Expand Down
8 changes: 6 additions & 2 deletions torchtext/datasets/cc100.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import os.path
from functools import partial

from torchdata.datapipes.iter import FileOpener, IterableWrapper
from torchtext._download_hooks import HttpReader
from torchtext._internal.module_utils import is_module_available
from torchtext.data.datasets_utils import (
_create_dataset_directory,
)
Expand Down Expand Up @@ -167,6 +166,11 @@ def CC100(root: str, language_code: str = "en"):
"""
if language_code not in VALID_CODES:
raise ValueError(f"Invalid language code {language_code}")
if not is_module_available("torchdata"):
raise ModuleNotFoundError(
"Package `torchdata` not found. Please install following instructions at https://github.com/pytorch/data"
)
from torchdata.datapipes.iter import FileOpener, GDriveReader, HttpReader, IterableWrapper # noqa

url = URL % language_code
url_dp = IterableWrapper([url])
Expand Down
12 changes: 6 additions & 6 deletions torchtext/datasets/cnndm.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,6 @@
from functools import partial
from typing import Union, Set, Tuple

from torchdata.datapipes.iter import (
FileOpener,
IterableWrapper,
OnlineReader,
GDriveReader,
)
from torchtext._internal.module_utils import is_module_available
from torchtext.data.datasets_utils import (
_wrap_split_argument,
Expand Down Expand Up @@ -141,6 +135,12 @@ def CNNDM(root: str, split: Union[Tuple[str], str]):
raise ModuleNotFoundError(
"Package `torchdata` not found. Please install following instructions at https://github.com/pytorch/data"
)
from torchdata.datapipes.iter import ( # noqa
FileOpener,
IterableWrapper,
OnlineReader,
GDriveReader,
)

cnn_dp = _load_stories(root, "cnn", split)
dailymail_dp = _load_stories(root, "dailymail", split)
Expand Down
3 changes: 1 addition & 2 deletions torchtext/datasets/cola.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
from functools import partial
from typing import Union, Tuple

from torchdata.datapipes.iter import FileOpener, IterableWrapper
from torchtext._download_hooks import HttpReader
from torchtext._internal.module_utils import is_module_available
from torchtext.data.datasets_utils import _create_dataset_directory, _wrap_split_argument

Expand Down Expand Up @@ -76,6 +74,7 @@ def CoLA(root: str, split: Union[Tuple[str], str]):
raise ModuleNotFoundError(
"Package `torchdata` not found. Please install following instructions at https://github.com/pytorch/data"
)
from torchdata.datapipes.iter import FileOpener, GDriveReader, HttpReader, IterableWrapper # noqa

url_dp = IterableWrapper([URL])
cache_compressed_dp = url_dp.on_disk_cache(
Expand Down
3 changes: 1 addition & 2 deletions torchtext/datasets/conll2000chunking.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
from functools import partial
from typing import Union, Tuple

from torchdata.datapipes.iter import FileOpener, IterableWrapper
from torchtext._download_hooks import HttpReader
from torchtext._internal.module_utils import is_module_available
from torchtext.data.datasets_utils import (
_wrap_split_argument,
Expand Down Expand Up @@ -68,6 +66,7 @@ def CoNLL2000Chunking(root: str, split: Union[Tuple[str], str]):
raise ModuleNotFoundError(
"Package `torchdata` not found. Please install following instructions at https://github.com/pytorch/data"
)
from torchdata.datapipes.iter import FileOpener, GDriveReader, HttpReader, IterableWrapper # noqa

url_dp = IterableWrapper([URL[split]])

Expand Down
3 changes: 1 addition & 2 deletions torchtext/datasets/dbpedia.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
from functools import partial
from typing import Union, Tuple

from torchdata.datapipes.iter import FileOpener, IterableWrapper
from torchtext._download_hooks import GDriveReader
from torchtext._internal.module_utils import is_module_available
from torchtext.data.datasets_utils import (
_wrap_split_argument,
Expand Down Expand Up @@ -75,6 +73,7 @@ def DBpedia(root: str, split: Union[Tuple[str], str]):
raise ModuleNotFoundError(
"Package `torchdata` not found. Please install following instructions at https://github.com/pytorch/data"
)
from torchdata.datapipes.iter import FileOpener, GDriveReader, HttpReader, IterableWrapper # noqa

url_dp = IterableWrapper([URL])
cache_compressed_dp = url_dp.on_disk_cache(
Expand Down
3 changes: 1 addition & 2 deletions torchtext/datasets/enwik9.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import os
from functools import partial

from torchdata.datapipes.iter import FileOpener, IterableWrapper
from torchtext._download_hooks import HttpReader
from torchtext._internal.module_utils import is_module_available
from torchtext.data.datasets_utils import _create_dataset_directory

Expand Down Expand Up @@ -50,6 +48,7 @@ def EnWik9(root: str):
raise ModuleNotFoundError(
"Package `torchdata` not found. Please install following instructions at https://github.com/pytorch/data"
)
from torchdata.datapipes.iter import FileOpener, GDriveReader, HttpReader, IterableWrapper # noqa

url_dp = IterableWrapper([URL])
cache_compressed_dp = url_dp.on_disk_cache(
Expand Down
3 changes: 1 addition & 2 deletions torchtext/datasets/imdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
from pathlib import Path
from typing import Tuple, Union

from torchdata.datapipes.iter import FileOpener, IterableWrapper
from torchtext._download_hooks import HttpReader
from torchtext._internal.module_utils import is_module_available
from torchtext.data.datasets_utils import _create_dataset_directory
from torchtext.data.datasets_utils import _wrap_split_argument
Expand Down Expand Up @@ -89,6 +87,7 @@ def IMDB(root: str, split: Union[Tuple[str], str]):
raise ModuleNotFoundError(
"Package `torchdata` not found. Please install following instructions at https://github.com/pytorch/data"
)
from torchdata.datapipes.iter import FileOpener, GDriveReader, HttpReader, IterableWrapper # noqa

url_dp = IterableWrapper([URL])

Expand Down
3 changes: 1 addition & 2 deletions torchtext/datasets/iwslt2016.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import os
from functools import partial

from torchdata.datapipes.iter import FileOpener, IterableWrapper
from torchtext._download_hooks import GDriveReader
from torchtext._internal.module_utils import is_module_available
from torchtext.data.datasets_utils import (
_clean_files,
Expand Down Expand Up @@ -219,6 +217,7 @@ def IWSLT2016(
raise ModuleNotFoundError(
"Package `torchdata` not found. Please install following instructions at https://github.com/pytorch/data"
)
from torchdata.datapipes.iter import FileOpener, GDriveReader, HttpReader, IterableWrapper # noqa

if not isinstance(language_pair, list) and not isinstance(language_pair, tuple):
raise ValueError("language_pair must be list or tuple but got {} instead".format(type(language_pair)))
Expand Down
3 changes: 1 addition & 2 deletions torchtext/datasets/iwslt2017.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import os
from functools import partial

from torchdata.datapipes.iter import FileOpener, IterableWrapper
from torchtext._download_hooks import GDriveReader
from torchtext._internal.module_utils import is_module_available
from torchtext.data.datasets_utils import (
_clean_files,
Expand Down Expand Up @@ -184,6 +182,7 @@ def IWSLT2017(root=".data", split=("train", "valid", "test"), language_pair=("de
raise ModuleNotFoundError(
"Package `torchdata` not found. Please install following instructions at https://github.com/pytorch/data"
)
from torchdata.datapipes.iter import FileOpener, GDriveReader, HttpReader, IterableWrapper # noqa

valid_set = "dev2010"
test_set = "tst2010"
Expand Down
5 changes: 2 additions & 3 deletions torchtext/datasets/mnli.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,9 @@
import os
from functools import partial

from torchdata.datapipes.iter import FileOpener, IterableWrapper

# we import HttpReader from _download_hooks so we can swap out public URLs
# with interal URLs when the dataset is used within Facebook
from torchtext._download_hooks import HttpReader

from torchtext._internal.module_utils import is_module_available
from torchtext.data.datasets_utils import (
_create_dataset_directory,
Expand Down Expand Up @@ -89,6 +87,7 @@ def MNLI(root, split):
raise ModuleNotFoundError(
"Package `torchdata` not found. Please install following instructions at https://github.com/pytorch/data"
)
from torchdata.datapipes.iter import FileOpener, GDriveReader, HttpReader, IterableWrapper # noqa

url_dp = IterableWrapper([URL])
cache_compressed_dp = url_dp.on_disk_cache(
Expand Down
2 changes: 1 addition & 1 deletion torchtext/datasets/mrpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from functools import partial
from typing import Union, Tuple

from torchdata.datapipes.iter import FileOpener, HttpReader, IterableWrapper
from torchtext._internal.module_utils import is_module_available
from torchtext.data.datasets_utils import (
_wrap_split_argument,
Expand Down Expand Up @@ -67,6 +66,7 @@ def MRPC(root: str, split: Union[Tuple[str], str]):
raise ModuleNotFoundError(
"Package `torchdata` not found. Please install following instructions at https://github.com/pytorch/data"
)
from torchdata.datapipes.iter import FileOpener, GDriveReader, HttpReader, IterableWrapper # noqa

url_dp = IterableWrapper([URL[split]])
# cache data on-disk with sanity check
Expand Down
6 changes: 3 additions & 3 deletions torchtext/datasets/multi30k.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@
from functools import partial
from typing import Union, Tuple

from torchdata.datapipes.iter import FileOpener, IterableWrapper
from torchtext._download_hooks import GDriveReader # noqa
from torchtext._download_hooks import HttpReader
# noqa

from torchtext._internal.module_utils import is_module_available
from torchtext.data.datasets_utils import (
_wrap_split_argument,
Expand Down Expand Up @@ -89,6 +88,7 @@ def Multi30k(root: str, split: Union[Tuple[str], str], language_pair: Tuple[str]
raise ModuleNotFoundError(
"Package `torchdata` not found. Please install following instructions at https://github.com/pytorch/data"
)
from torchdata.datapipes.iter import FileOpener, GDriveReader, HttpReader, IterableWrapper # noqa

url_dp = IterableWrapper([URL[split]])

Expand Down
6 changes: 3 additions & 3 deletions torchtext/datasets/penntreebank.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@
from functools import partial
from typing import Tuple, Union

from torchdata.datapipes.iter import FileOpener, IterableWrapper
from torchtext._download_hooks import GDriveReader # noqa
from torchtext._download_hooks import HttpReader
# noqa

from torchtext._internal.module_utils import is_module_available
from torchtext.data.datasets_utils import (
_wrap_split_argument,
Expand Down Expand Up @@ -70,6 +69,7 @@ def PennTreebank(root, split: Union[Tuple[str], str]):
raise ModuleNotFoundError(
"Package `torchdata` not found. Please install following instructions at https://github.com/pytorch/data"
)
from torchdata.datapipes.iter import FileOpener, GDriveReader, HttpReader, IterableWrapper # noqa

url_dp = IterableWrapper([URL[split]])
cache_dp = url_dp.on_disk_cache(
Expand Down
5 changes: 2 additions & 3 deletions torchtext/datasets/qnli.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,9 @@
import os
from functools import partial

from torchdata.datapipes.iter import FileOpener, IterableWrapper

# we import HttpReader from _download_hooks so we can swap out public URLs
# with interal URLs when the dataset is used within Facebook
from torchtext._download_hooks import HttpReader

from torchtext._internal.module_utils import is_module_available
from torchtext.data.datasets_utils import (
_create_dataset_directory,
Expand Down Expand Up @@ -81,6 +79,7 @@ def QNLI(root, split):
raise ModuleNotFoundError(
"Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`"
)
from torchdata.datapipes.iter import FileOpener, GDriveReader, HttpReader, IterableWrapper # noqa

url_dp = IterableWrapper([URL])
cache_compressed_dp = url_dp.on_disk_cache(
Expand Down
3 changes: 1 addition & 2 deletions torchtext/datasets/qqp.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import os
from functools import partial

from torchdata.datapipes.iter import FileOpener, IterableWrapper
from torchtext._download_hooks import HttpReader
from torchtext._internal.module_utils import is_module_available
from torchtext.data.datasets_utils import _create_dataset_directory

Expand Down Expand Up @@ -48,6 +46,7 @@ def QQP(root: str):
raise ModuleNotFoundError(
"Package `torchdata` not found. Please install following instructions at https://github.com/pytorch/data"
)
from torchdata.datapipes.iter import FileOpener, GDriveReader, HttpReader, IterableWrapper # noqa

url_dp = IterableWrapper([URL])
cache_dp = url_dp.on_disk_cache(
Expand Down
5 changes: 2 additions & 3 deletions torchtext/datasets/rte.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,9 @@
import os
from functools import partial

from torchdata.datapipes.iter import FileOpener, IterableWrapper

# we import HttpReader from _download_hooks so we can swap out public URLs
# with interal URLs when the dataset is used within Facebook
from torchtext._download_hooks import HttpReader

from torchtext._internal.module_utils import is_module_available
from torchtext.data.datasets_utils import (
_create_dataset_directory,
Expand Down Expand Up @@ -81,6 +79,7 @@ def RTE(root, split):
raise ModuleNotFoundError(
"Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`"
)
from torchdata.datapipes.iter import FileOpener, GDriveReader, HttpReader, IterableWrapper # noqa

url_dp = IterableWrapper([URL])
cache_compressed_dp = url_dp.on_disk_cache(
Expand Down
3 changes: 1 addition & 2 deletions torchtext/datasets/sogounews.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
from functools import partial
from typing import Union, Tuple

from torchdata.datapipes.iter import FileOpener, IterableWrapper
from torchtext._download_hooks import GDriveReader
from torchtext._internal.module_utils import is_module_available
from torchtext.data.datasets_utils import (
_wrap_split_argument,
Expand Down Expand Up @@ -79,6 +77,7 @@ def SogouNews(root: str, split: Union[Tuple[str], str]):
raise ModuleNotFoundError(
"Package `torchdata` not found. Please install following instructions at https://github.com/pytorch/data"
)
from torchdata.datapipes.iter import FileOpener, GDriveReader, HttpReader, IterableWrapper # noqa

url_dp = IterableWrapper([URL])
cache_compressed_dp = url_dp.on_disk_cache(
Expand Down
3 changes: 1 addition & 2 deletions torchtext/datasets/squad1.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
from functools import partial
from typing import Union, Tuple

from torchdata.datapipes.iter import FileOpener, IterableWrapper
from torchtext._download_hooks import HttpReader
from torchtext._internal.module_utils import is_module_available
from torchtext.data.datasets_utils import (
_wrap_split_argument,
Expand Down Expand Up @@ -62,6 +60,7 @@ def SQuAD1(root: str, split: Union[Tuple[str], str]):
raise ModuleNotFoundError(
"Package `torchdata` not found. Please install following instructions at https://github.com/pytorch/data"
)
from torchdata.datapipes.iter import FileOpener, GDriveReader, HttpReader, IterableWrapper # noqa

url_dp = IterableWrapper([URL[split]])
# cache data on-disk with sanity check
Expand Down
Loading

0 comments on commit bc0913f

Please sign in to comment.