Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 96fb56b

Browse files
ywkim authored and afrozenator committed
MRPC: Exclude dev data from training dataset (#1281)
1 parent 8e76aaa commit 96fb56b

File tree

1 file changed

+14 -5 lines changed
  • tensor2tensor/data_generators

1 file changed

+14 -5 lines changed

tensor2tensor/data_generators/mrpc.py

Lines changed: 14 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -58,6 +58,9 @@ def dataset_splits(self):
5858
}, {
5959
"split": problem.DatasetSplit.EVAL,
6060
"shards": 1,
61+
}, {
62+
"split": problem.DatasetSplit.TEST,
63+
"shards": 1,
6164
}]
6265

6366
@property
@@ -89,15 +92,18 @@ def download_file(tdir, filepath, url):
8992

9093
return mrpc_dir
9194

92-
def example_generator(self, filename, dev_ids):
95+
def example_generator(self, filename, dev_ids, dataset_split):
9396
for idx, line in enumerate(tf.gfile.Open(filename, "rb")):
9497
if idx == 0: continue # skip header
9598
if six.PY2:
9699
line = unicode(line.strip(), "utf-8")
97100
else:
98101
line = line.strip().decode("utf-8")
99102
l, id1, id2, s1, s2 = line.split("\t")
100-
if dev_ids and [id1, id2] not in dev_ids:
103+
is_dev = [id1, id2] in dev_ids
104+
if dataset_split == problem.DatasetSplit.TRAIN and is_dev:
105+
continue
106+
if dataset_split == problem.DatasetSplit.EVAL and not is_dev:
101107
continue
102108
inputs = [[s1, s2], [s2, s1]]
103109
for inp in inputs:
@@ -108,14 +114,17 @@ def example_generator(self, filename, dev_ids):
108114

109115
def generate_samples(self, data_dir, tmp_dir, dataset_split):
110116
mrpc_dir = self._maybe_download_corpora(tmp_dir)
111-
filesplit = "msr_paraphrase_train.txt"
117+
if dataset_split != problem.DatasetSplit.TEST:
118+
filesplit = "msr_paraphrase_train.txt"
119+
else:
120+
filesplit = "msr_paraphrase_test.txt"
112121
dev_ids = []
113-
if dataset_split != problem.DatasetSplit.TRAIN:
122+
if dataset_split != problem.DatasetSplit.TEST:
114123
for row in tf.gfile.Open(os.path.join(mrpc_dir, "dev_ids.tsv")):
115124
dev_ids.append(row.strip().split("\t"))
116125

117126
filename = os.path.join(mrpc_dir, filesplit)
118-
for example in self.example_generator(filename, dev_ids):
127+
for example in self.example_generator(filename, dev_ids, dataset_split):
119128
yield example
120129

121130

0 commit comments

Comments
 (0)