Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 7045453

Browse files
yuwen-yan authored and Copybara-Service committed
internal merge of PR #1253
PiperOrigin-RevId: 223245334
1 parent eacbb18 commit 7045453

File tree

1 file changed

+92
-0
lines changed

1 file changed

+92
-0
lines changed
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
# coding=utf-8
2+
# Copyright 2018 The Tensor2Tensor Authors.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Data generators for CoNLL dataset."""
17+
18+
from __future__ import absolute_import
19+
from __future__ import division
20+
from __future__ import print_function
21+
22+
import os
23+
import zipfile
24+
25+
from tensor2tensor.data_generators import generator_utils
26+
from tensor2tensor.data_generators import problem
27+
from tensor2tensor.data_generators import text_problems
28+
from tensor2tensor.utils import registry
29+
import tensorflow as tf
30+
31+
32+
@registry.register_problem
class Conll2002Ner(text_problems.Text2textTmpdir):
  """Base class for CoNLL2002 named-entity recognition problems.

  Subclasses pick the language by overriding `source_data_files` to return
  the per-split corpus filenames inside the downloaded `conll2002` archive.
  Each generated sample is a whitespace-joined sentence ("inputs") paired
  with the whitespace-joined NER tag sequence ("targets").
  """

  def source_data_files(self, dataset_split):
    """Returns the list of corpus filenames to read for `dataset_split`."""
    raise NotImplementedError()

  def generate_samples(self, data_dir, tmp_dir, dataset_split):
    """Yields {"inputs": sentence, "targets": tags} dicts for one split.

    Downloads (if needed) and extracts the NLTK copy of the CoNLL2002
    corpus into `tmp_dir`, then parses the column-formatted files.

    Args:
      data_dir: unused (required by the Problem API).
      tmp_dir: directory for the downloaded/extracted corpus.
      dataset_split: a `problem.DatasetSplit` value selecting the files.

    Yields:
      Dicts with "inputs" (space-joined words) and "targets"
      (space-joined tags), one per sentence.
    """
    del data_dir

    url = "https://raw.githubusercontent.com/nltk/nltk_data/gh-pages/packages/corpora/conll2002.zip"  # pylint: disable=line-too-long
    compressed_filename = os.path.basename(url)
    compressed_filepath = os.path.join(tmp_dir, compressed_filename)
    generator_utils.maybe_download(tmp_dir, compressed_filename, url)

    # Bug fix: the original used `compressed_filepath.strip(".zip")`, but
    # str.strip removes a *character set* from both ends, not a suffix (it
    # only worked here by accident). splitext drops the extension correctly.
    compressed_dir = os.path.splitext(compressed_filepath)[0]

    filenames = self.source_data_files(dataset_split)
    for filename in filenames:
      filepath = os.path.join(compressed_dir, filename)
      if not tf.gfile.Exists(filepath):
        with zipfile.ZipFile(compressed_filepath, "r") as corpus_zip:
          corpus_zip.extractall(tmp_dir)
      with tf.gfile.GFile(filepath, mode="r") as cur_file:
        words, tags = [], []
        for line in cur_file:
          line_split = line.strip().split()
          if not line_split:
            # Blank line terminates a sentence. Guard against leading or
            # consecutive blank lines, which previously yielded empty
            # {"inputs": "", "targets": ""} samples.
            if words:
              yield {
                  "inputs": " ".join(words),
                  "targets": " ".join(tags)
              }
              words, tags = [], []
            continue
          words.append(line_split[0])
          # Third column holds the tag of interest; presumably the NER
          # chunk tag in the CoNLL2002 column format — TODO confirm.
          tags.append(line_split[2])
        # Flush a final sentence not followed by a trailing blank line.
        if words:
          yield {"inputs": " ".join(words), "targets": " ".join(tags)}
71+
72+
73+
@registry.register_problem
class Conll2002EsNer(Conll2002Ner):
  """Problem spec for CoNLL2002 Spanish named entity task."""
  TRAIN_FILES = ["esp.train"]
  EVAL_FILES = ["esp.testa", "esp.testb"]

  def source_data_files(self, dataset_split):
    """Returns the Spanish corpus filenames for the given split."""
    if dataset_split == problem.DatasetSplit.TRAIN:
      return self.TRAIN_FILES
    return self.EVAL_FILES
82+
83+
84+
@registry.register_problem
class Conll2002NlNer(Conll2002Ner):
  """Problem spec for CoNLL2002 Dutch named entity task."""
  TRAIN_FILES = ["ned.train"]
  EVAL_FILES = ["ned.testa", "ned.testb"]

  def source_data_files(self, dataset_split):
    """Returns the Dutch corpus filenames for the given split."""
    if dataset_split == problem.DatasetSplit.TRAIN:
      return self.TRAIN_FILES
    return self.EVAL_FILES

0 commit comments

Comments
 (0)