@@ -907,6 +907,74 @@ def __init__(
907
907
super ().__init__ (data_folder , tokenizer = tokenizer , memory_mode = memory_mode , ** corpusargs )
908
908
909
909
910
class AGNEWS(ClassificationCorpus):
    """The AG's News Topic Classification Corpus, classifying news into 4 coarse-grained topics.

    Labels: World, Sports, Business, Sci/Tech.
    """

    def __init__(
        self,
        base_path: Optional[Union[str, Path]] = None,
        tokenizer: Union[bool, Tokenizer] = SpaceTokenizer(),
        memory_mode: str = "partial",
        **corpusargs,
    ) -> None:
        """Instantiates AGNews Classification Corpus with 4 classes.

        :param base_path: Provide this only if you store the AGNEWS corpus in a specific folder, otherwise use default.
        :param tokenizer: Custom tokenizer to use (default is SpaceTokenizer)
        :param memory_mode: Set to 'partial' by default. Can also be 'full' or 'none'.
        :param corpusargs: Other args for ClassificationCorpus.
        """
        base_path = flair.cache_root / "datasets" if not base_path else Path(base_path)

        dataset_name = self.__class__.__name__.lower()
        data_folder = base_path / dataset_name

        # download data from same source as in huggingface's implementations
        agnews_path = "https://raw.githubusercontent.com/mhjabreel/CharCnn_Keras/master/data/ag_news_csv/"

        original_filenames = ["train.csv", "test.csv", "classes.txt"]
        new_filenames = ["train.txt", "test.txt"]

        for original_filename in original_filenames:
            cached_path(f"{agnews_path}{original_filename}", Path("datasets") / dataset_name / "original")

        data_file = data_folder / new_filenames[0]

        # Convert the original CSVs to FastText-style label files only once;
        # subsequent constructions reuse the converted files on disk.
        if not data_file.is_file():
            # Read the label names in file order: line i names original label id i+1.
            # (Explicit encoding keeps this consistent with the other opens below.)
            with open(data_folder / "original" / original_filenames[-1], encoding="utf-8") as label_fp:
                labels = [line.rstrip() for line in label_fp]

            for original_filename, new_filename in zip(original_filenames[:-1], new_filenames):
                with open(data_folder / "original" / original_filename, encoding="utf-8") as open_fp, open(
                    data_folder / new_filename, "w", encoding="utf-8"
                ) as write_fp:
                    csv_reader = csv.reader(
                        open_fp, quotechar='"', delimiter=",", quoting=csv.QUOTE_ALL, skipinitialspace=True
                    )
                    for row in csv_reader:
                        label, title, description = row
                        # Original labels are [1, 2, 3, 4] -> ['World', 'Sports', 'Business', 'Sci/Tech'];
                        # map the 1-based id to its name from classes.txt.
                        text = " ".join((title, description))
                        new_label = "__label__" + labels[int(label) - 1]
                        write_fp.write(f"{new_label} {text}\n")

        super().__init__(data_folder, label_type="topic", tokenizer=tokenizer, memory_mode=memory_mode, **corpusargs)
910
978
class STACKOVERFLOW (ClassificationCorpus ):
911
979
"""Stackoverflow corpus classifying questions into one of 20 labels.
912
980
0 commit comments