Skip to content

Commit ddf3bb3

Browse files
authored
Merge pull request #3385 from flairNLP/agnews_dataset
Add AGNews corpus
2 parents c84776c + f180771 commit ddf3bb3

File tree

2 files changed

+70
-0
lines changed

2 files changed

+70
-0
lines changed

flair/datasets/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@
100100

101101
# Expose all document classification datasets
102102
from .document_classification import (
103+
AGNEWS,
103104
AMAZON_REVIEWS,
104105
COMMUNICATIVE_FUNCTIONS,
105106
GERMEVAL_2018_OFFENSIVE_LANGUAGE,
@@ -314,6 +315,7 @@
314315
"SentenceDataset",
315316
"MongoDataset",
316317
"StringDataset",
318+
"AGNEWS",
317319
"ANAT_EM",
318320
"AZDZ",
319321
"BC2GM",

flair/datasets/document_classification.py

+68
Original file line numberDiff line numberDiff line change
@@ -907,6 +907,74 @@ def __init__(
907907
super().__init__(data_folder, tokenizer=tokenizer, memory_mode=memory_mode, **corpusargs)
908908

909909

910+
class AGNEWS(ClassificationCorpus):
    """The AG's News Topic Classification Corpus, classifying news into 4 coarse-grained topics.

    Labels: World, Sports, Business, Sci/Tech.
    """

    def __init__(
        self,
        base_path: Optional[Union[str, Path]] = None,
        tokenizer: Union[bool, Tokenizer] = SpaceTokenizer(),
        memory_mode="partial",
        **corpusargs,
    ):
        """Instantiates AGNews Classification Corpus with 4 classes.

        The original CSV files and the class list are downloaded, then converted once into
        FastText-style files (``__label__<name> <title> <description>`` per line) named
        ``train.txt`` and ``test.txt`` inside ``<base_path>/agnews``. The conversion is skipped
        when ``train.txt`` already exists.

        :param base_path: Provide this only if you store the AGNEWS corpus in a specific folder, otherwise use default.
        :param tokenizer: Custom tokenizer to use (default is SpaceTokenizer)
        :param memory_mode: Set to 'partial' by default. Can also be 'full' or 'none'.
        :param corpusargs: Other args for ClassificationCorpus.
        """
        base_path = flair.cache_root / "datasets" if not base_path else Path(base_path)

        dataset_name = self.__class__.__name__.lower()

        data_folder = base_path / dataset_name

        # download data from same source as in huggingface's implementations
        agnews_path = "https://raw.githubusercontent.com/mhjabreel/CharCnn_Keras/master/data/ag_news_csv/"

        csv_filenames = ["train.csv", "test.csv"]
        label_filename = "classes.txt"
        new_filenames = ["train.txt", "test.txt"]

        # Remember where each file actually landed. The download location is derived from the
        # cache directory, not from `base_path`, so rebuilding the path from `data_folder`
        # fails with FileNotFoundError whenever a custom `base_path` is given.
        downloaded = {
            filename: cached_path(f"{agnews_path}{filename}", Path("datasets") / dataset_name / "original")
            for filename in [*csv_filenames, label_filename]
        }

        data_file = data_folder / new_filenames[0]
        if not data_file.is_file():
            # With a custom base_path nothing has created the output folder yet.
            data_folder.mkdir(parents=True, exist_ok=True)

            # classes.txt lists one label name per line; line i (1-based) names label id i.
            with open(downloaded[label_filename], encoding="utf-8") as label_fp:
                label_names = [line.rstrip() for line in label_fp]

            for csv_filename, new_filename in zip(csv_filenames, new_filenames):
                with open(downloaded[csv_filename], encoding="utf-8") as open_fp, open(
                    data_folder / new_filename, "w", encoding="utf-8"
                ) as write_fp:
                    csv_reader = csv.reader(
                        open_fp, quotechar='"', delimiter=",", quoting=csv.QUOTE_ALL, skipinitialspace=True
                    )
                    for label, title, description in csv_reader:
                        # Original labels are [1, 2, 3, 4] -> ['World', 'Sports', 'Business', 'Sci/Tech'];
                        # shift to a 0-based index to look up the label name.
                        label_name = label_names[int(label) - 1]
                        text = " ".join((title, description))

                        write_fp.write(f"__label__{label_name} {text}\n")

        super().__init__(data_folder, label_type="topic", tokenizer=tokenizer, memory_mode=memory_mode, **corpusargs)
976+
977+
910978
class STACKOVERFLOW(ClassificationCorpus):
911979
"""Stackoverflow corpus classifying questions into one of 20 labels.
912980

0 commit comments

Comments
 (0)