@@ -58,6 +58,9 @@ def dataset_splits(self):
5858 }, {
5959 "split" : problem .DatasetSplit .EVAL ,
6060 "shards" : 1 ,
61+ }, {
62+ "split" : problem .DatasetSplit .TEST ,
63+ "shards" : 1 ,
6164 }]
6265
6366 @property
@@ -89,15 +92,18 @@ def download_file(tdir, filepath, url):
8992
9093 return mrpc_dir
9194
92- def example_generator (self , filename , dev_ids ):
95+ def example_generator (self , filename , dev_ids , dataset_split ):
9396 for idx , line in enumerate (tf .gfile .Open (filename , "rb" )):
9497 if idx == 0 : continue # skip header
9598 if six .PY2 :
9699 line = unicode (line .strip (), "utf-8" )
97100 else :
98101 line = line .strip ().decode ("utf-8" )
99102 l , id1 , id2 , s1 , s2 = line .split ("\t " )
100- if dev_ids and [id1 , id2 ] not in dev_ids :
103+ is_dev = [id1 , id2 ] in dev_ids
104+ if dataset_split == problem .DatasetSplit .TRAIN and is_dev :
105+ continue
106+ if dataset_split == problem .DatasetSplit .EVAL and not is_dev :
101107 continue
102108 inputs = [[s1 , s2 ], [s2 , s1 ]]
103109 for inp in inputs :
@@ -108,14 +114,17 @@ def example_generator(self, filename, dev_ids):
108114
109115 def generate_samples (self , data_dir , tmp_dir , dataset_split ):
110116 mrpc_dir = self ._maybe_download_corpora (tmp_dir )
111- filesplit = "msr_paraphrase_train.txt"
117+ if dataset_split != problem .DatasetSplit .TEST :
118+ filesplit = "msr_paraphrase_train.txt"
119+ else :
120+ filesplit = "msr_paraphrase_test.txt"
112121 dev_ids = []
113- if dataset_split != problem .DatasetSplit .TRAIN :
122+ if dataset_split != problem .DatasetSplit .TEST :
114123 for row in tf .gfile .Open (os .path .join (mrpc_dir , "dev_ids.tsv" )):
115124 dev_ids .append (row .strip ().split ("\t " ))
116125
117126 filename = os .path .join (mrpc_dir , filesplit )
118- for example in self .example_generator (filename , dev_ids ):
127+ for example in self .example_generator (filename , dev_ids , dataset_split ):
119128 yield example
120129
121130
0 commit comments